Fine-tune an LLM using TPUs on GKE with JAX

This tutorial shows you how to fine-tune a large language model (LLM) using Tensor Processing Units (TPUs) on Google Kubernetes Engine (GKE) with JAX. Fine-tuning lets you adapt a foundation model like Gemma 3 to a specific domain or task. This process improves the precision and accuracy of the model by updating its parameters with your own specialized dataset.

This guide is a good starting point if you need the granular control, customization, scalability, resilience, portability, and cost-effectiveness of managed Kubernetes when fine-tuning your AI/ML workloads.

Background

By using TPUs on GKE with JAX to fine-tune an LLM, you can build a robust, production-ready fine-tuning solution with all the benefits of managed Kubernetes.

Gemma

Gemma is a set of openly available, lightweight, generative AI models released under an open license. These AI models are available to run in your applications, hardware, mobile devices, or hosted services. Gemma 3 introduces multimodality: it supports vision-language input and text output, handles context windows of up to 128,000 tokens, and supports over 140 languages. Gemma 3 also offers improved math, reasoning, and chat capabilities, including structured outputs and function calling.

You can use the Gemma models for text generation, or you can tune them for specialized tasks.

For more information, see the Gemma documentation.

TPUs

TPUs are application-specific integrated circuits (ASICs) that Google custom-developed to accelerate machine learning and AI models that are built using frameworks such as TensorFlow, PyTorch, and JAX.

Before you use TPUs in GKE, we recommend that you complete the following learning path:

  1. Learn about current TPU version availability with the Cloud TPU system architecture.
  2. Learn about TPUs in GKE.

JAX

JAX is a high-performance machine learning framework that is designed for use with accelerators such as TPUs and GPUs. JAX provides an API for building and training machine learning models, including composable transformations such as jax.jit for compilation and jax.grad for automatic differentiation.

To learn more, see the JAX repository.
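As a minimal sketch of that API, the following example uses jax.jit to compile a loss function and jax.grad to differentiate it, the same two transformations that the fine-tuning script in this tutorial relies on. The linear model and values here are invented for illustration and aren't part of the tutorial's code:

```python
import jax
import jax.numpy as jnp

# jax.jit compiles the function with XLA for the available accelerator
# (CPU, GPU, or TPU); jax.grad builds its gradient function.
@jax.jit
def loss_fn(w, x, y):
    pred = x @ w                      # toy linear model
    return jnp.mean((pred - y) ** 2)  # mean squared error

grad_fn = jax.grad(loss_fn)  # gradient with respect to the first argument, w

w = jnp.array([0.5, -0.5])
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])

loss = loss_fn(w, x, y)
grads = grad_fn(w, x, y)
print(loss)   # scalar loss
print(grads)  # gradient with the same shape as w
```

The fine-tuning script in this tutorial applies the same pattern at scale: a jitted train step whose gradients come from jax.value_and_grad.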

Objectives

This tutorial covers the following steps:

  1. Create a GKE Autopilot or Standard cluster with the recommended TPU topology, based on the model characteristics. During this tutorial, you perform the fine-tuning on single-host node pools.
  2. Add data to a Cloud Storage bucket and mount it to the container through Cloud Storage FUSE.
  3. Deploy the LLM fine-tuning Job on GKE.
  4. Monitor the fine-tuning Job and view the logs.

Before you begin

  • Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
  • In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Roles required to select or create a project

    • Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
    • Create a project: To create a project, you need the Project Creator role (roles/resourcemanager.projectCreator), which contains the resourcemanager.projects.create permission. Learn how to grant roles.

    Go to project selector

  • Verify that billing is enabled for your Google Cloud project.

  • Enable the required API.

    Roles required to enable APIs

    To enable APIs, you need the Service Usage Admin IAM role (roles/serviceusage.serviceUsageAdmin), which contains the serviceusage.services.enable permission. Learn how to grant roles.

    Enable the API


  • Make sure that you have the following role or roles on the project: roles/container.admin, roles/iam.serviceAccountAdmin, roles/storage.admin

    Check for the roles

    1. In the Google Cloud console, go to the IAM page.

      Go to IAM
    2. Select the project.
    3. In the Principal column, find all rows that identify you or a group that you're included in. To learn which groups you're included in, contact your administrator.

    4. For all rows that specify or include you, check the Role column to see whether the list of roles includes the required roles.

    Grant the roles

    1. In the Google Cloud console, go to the IAM page.

      Go to IAM
    2. Select the project.
    3. Click Grant access.
    4. In the New principals field, enter your user identifier. This is typically the email address for a Google Account.

    5. Click Select a role, then search for the role.
    6. To grant additional roles, click Add another role and add each additional role.
    7. Click Save.
  • Ensure that you have sufficient quota for 16 TPU Trillium (v6e) chips. In this tutorial, you use a node pool configuration that requires 16 chips and on-demand instances.
  • Ensure that you have a Docker repository. If you don't have one, create a standard repository in Artifact Registry.

Prepare the environment

In this tutorial, you use Cloud Shell to manage resources hosted on Google Cloud. Cloud Shell comes preinstalled with the software you need for this tutorial, including kubectl and Google Cloud CLI.

To set up your environment with Cloud Shell, follow these steps:

  1. In the Google Cloud console, launch a Cloud Shell session by clicking Activate Cloud Shell. This action launches a session in the bottom pane of the Google Cloud console.

  2. Set the default environment variables:

    gcloud config set project PROJECT_ID
    gcloud config set billing/quota_project PROJECT_ID
    export PROJECT_ID=$(gcloud config get project)
    export CLUSTER_NAME=CLUSTER_NAME
    export REGION=CONTROL_PLANE_LOCATION
    export ZONE=ZONE
    export GCS_BUCKET_NAME=BUCKET_NAME
    

    Replace the following values:

    • PROJECT_ID: your Google Cloud project ID.
    • CLUSTER_NAME: the name of your GKE cluster.
    • CONTROL_PLANE_LOCATION: the Compute Engine region where your GKE cluster and TPU nodes are located. The region must contain zones where TPU Trillium (v6e) machine types are available.
    • ZONE: a zone within your selected CONTROL_PLANE_LOCATION region where TPU Trillium (v6e) machine types are available. To list zones where TPU Trillium (v6e) TPUs are available, run the following command:

        gcloud compute accelerator-types list --filter="name~ct6e" --format="value(zone)"
      
    • BUCKET_NAME: the name of the Cloud Storage bucket that contains your training data.

  3. Clone the sample repository:

    git clone https://github.com/GoogleCloudPlatform/kubernetes-engine-samples.git
    cd kubernetes-engine-samples
    
  4. Navigate to the working directory:

    cd ai-ml/llm-training-jax-tpu-gemma3
    

Create and configure Google Cloud resources

In this section, you create and configure Google Cloud resources.

Create a GKE cluster

You can fine-tune an LLM on TPUs in a GKE Autopilot or Standard cluster. We recommend that you use an Autopilot cluster for a fully managed Kubernetes experience. To choose the GKE mode of operation that's the best fit for your workloads, see Choose a GKE mode of operation.

Autopilot

Create a GKE Autopilot cluster that uses Workload Identity Federation for GKE and has Cloud Storage FUSE enabled.

gcloud container clusters create-auto ${CLUSTER_NAME} \
    --location=${REGION}

The cluster creation might take several minutes.

Standard

  1. Create a regional GKE Standard cluster that uses Workload Identity Federation for GKE and has Cloud Storage FUSE enabled.

    gcloud container clusters create ${CLUSTER_NAME} \
        --enable-ip-alias \
        --addons GcsFuseCsiDriver \
        --machine-type=n2-standard-4 \
        --num-nodes=2 \
        --workload-pool=${PROJECT_ID}.svc.id.goog \
        --location=${REGION}
    

    The cluster creation might take several minutes.

  2. Create a single-host node pool:

    gcloud container node-pools create jax-tpu-nodepool \
        --cluster=${CLUSTER_NAME} \
        --machine-type=ct6e-standard-1t \
        --num-nodes=1 \
        --location=${REGION} \
        --node-locations=${ZONE} \
        --workload-metadata=GKE_METADATA
    

GKE creates a TPU Trillium node pool with a 1x1 topology and one node. The --workload-metadata=GKE_METADATA flag configures the node pool to use the GKE metadata server.

Install JobSet

  1. Configure kubectl to communicate with your cluster:

    gcloud container clusters get-credentials ${CLUSTER_NAME} --location=${REGION}
    
  2. Install the latest released version of JobSet:

    kubectl apply --server-side -f https://github.com/kubernetes-sigs/jobset/releases/download/JOBSET_VERSION/manifests.yaml
    

    Replace JOBSET_VERSION with the latest released version of JobSet. For example, v0.11.0.

  3. Verify the JobSet installation:

    kubectl get pods -n jobset-system
    

    The output is similar to the following:

    NAME                                         READY   STATUS    RESTARTS   AGE
    jobset-controller-manager-6c56668494-l4dhc   1/1     Running   0          4m45s
    

    You might need to add more nodes if JobSet is waiting for resources.

Configure Cloud Storage FUSE

To fine-tune the LLM, you need to provide training data. In this tutorial, you use the TinyStories dataset from Hugging Face. This dataset contains short stories, synthetically generated by GPT-3.5 and GPT-4, that use a limited vocabulary.
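The training file stores all of the stories in a single stream, separated by the <|endoftext|> delimiter, and the fine-tuning script later splits on that delimiter. The following pure-Python sketch shows that split with invented story text:

```python
# A miniature stand-in for TinyStories-train.txt: stories separated by
# the <|endoftext|> delimiter (the story text here is invented).
raw = (
    "Once upon a time, there was a dog named Rex.<|endoftext|>"
    "Tom had a red ball. He loved it.<|endoftext|>"
    "   "  # trailing whitespace-only fragment
)

# Split on the delimiter and drop empty or whitespace-only fragments,
# the same filtering that the tutorial's training script applies.
stories = [story for story in raw.split("<|endoftext|>") if story.strip()]

print(len(stories))   # 2
print(stories[0])
```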

This section covers the steps to configure Cloud Storage FUSE to read data from a Cloud Storage bucket.

  1. Download the dataset:

    wget https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories-train.txt?download=true -O TinyStories-train.txt
    
  2. Upload the data into a new Cloud Storage bucket:

    gcloud storage buckets create gs://${GCS_BUCKET_NAME} \
        --location=${REGION} \
        --enable-hierarchical-namespace \
        --uniform-bucket-level-access
    gcloud storage cp TinyStories-train.txt gs://${GCS_BUCKET_NAME}
    
  3. To allow your workload to read data through Cloud Storage FUSE, create a Kubernetes service account (KSA) and add the required permissions. In the permissionsetup.sh script, replace the <GSA_NAME> and <GCS_BUCKET_NAME> placeholders with your values, and then run the script:

    #!/bin/bash
    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    # --- Configuration Variables ---
    # Kubernetes Service Account details
    export KSA_NAME="jaxserviceaccout"
    export NAMESPACE="default"
    
    # Google Cloud IAM Service Account details
    export GSA_NAME="<GSA_NAME>"
    # Automatically get the current project ID
    export PROJECT_ID=$(gcloud config get-value project)
    export GSA_DESCRIPTION="GKE Service Account to read GCS bucket for ${KSA_NAME}"
    
    # GCS Bucket details
    export GCS_BUCKET_NAME="<GCS_BUCKET_NAME>" # <--- IMPORTANT: Update this to your bucket name
    
    # Derived Variables
    export GSA_EMAIL="${GSA_NAME}@${PROJECT_ID}.iam.gserviceaccount.com"
    export WI_MEMBER="serviceAccount:${PROJECT_ID}.svc.id.goog[${NAMESPACE}/${KSA_NAME}]"
    
    # --- Check if PROJECT_ID is set ---
    if [ -z "${PROJECT_ID}" ]; then
      echo "Error: PROJECT_ID is not set. Please set it using 'gcloud config set project YOUR_PROJECT_ID'"
      exit 1
    fi
    
    echo "--- Configuration ---"
    echo "KSA_NAME:      ${KSA_NAME}"
    echo "NAMESPACE:     ${NAMESPACE}"
    echo "GSA_NAME:      ${GSA_NAME}"
    echo "PROJECT_ID:    ${PROJECT_ID}"
    echo "GSA_EMAIL:     ${GSA_EMAIL}"
    echo "GCS_BUCKET_NAME:   ${GCS_BUCKET_NAME}"
    echo "WI_MEMBER:     ${WI_MEMBER}"
    echo "--------------------"
    read -p "Press enter to continue..."
    
    # --- Command Execution ---
    
    echo "[1/5] Creating Google Cloud IAM Service Account (GSA): ${GSA_NAME}"
    gcloud iam service-accounts create "${GSA_NAME}" \
        --project="${PROJECT_ID}" \
        --description="${GSA_DESCRIPTION}" \
        --display-name="${GSA_NAME}"
    
    echo "[2/5] Granting GSA '${GSA_EMAIL}' read access (roles/storage.objectViewer) to bucket 'gs://${GCS_BUCKET_NAME}'"
    gcloud storage buckets add-iam-policy-binding "gs://${GCS_BUCKET_NAME}" \
        --member="serviceAccount:${GSA_EMAIL}" \
        --role="roles/storage.objectViewer" \
        --project="${PROJECT_ID}"
    
    echo "[3/5] Creating Kubernetes Service Account (KSA): ${KSA_NAME} in namespace ${NAMESPACE}"
    kubectl create serviceaccount "${KSA_NAME}" --namespace "${NAMESPACE}"
    
    echo "[4/5] Allowing KSA to impersonate GSA (Workload Identity Binding): ${GSA_EMAIL}"
    gcloud iam service-accounts add-iam-policy-binding "${GSA_EMAIL}" \
        --role roles/iam.workloadIdentityUser \
        --member "${WI_MEMBER}" \
        --project="${PROJECT_ID}"
    
    echo "[5/5] Annotating KSA '${KSA_NAME}' to link with GSA '${GSA_EMAIL}'"
    kubectl annotate serviceaccount "${KSA_NAME}" \
        --namespace "${NAMESPACE}" \
        iam.gke.io/gcp-service-account="${GSA_EMAIL}"
    
    echo "--- Setup Complete ---"
    echo "Pods in namespace '${NAMESPACE}' using serviceAccount '${KSA_NAME}' can now authenticate as '${GSA_EMAIL}' and have read access to 'gs://${GCS_BUCKET_NAME}'."
    

    After you run this script, the following resources are configured in your Google Cloud project and GKE cluster:

    • A new IAM service account (GSA) with the name that you specified for GSA_NAME is created in your project.
    • The GSA is granted the roles/storage.objectViewer role on the Cloud Storage bucket specified by GCS_BUCKET_NAME. This permission allows the GSA to read objects from the bucket.
    • A new KSA named jaxserviceaccout is created in the default namespace within your GKE cluster.
    • The IAM policy of the GSA is updated to grant the roles/iam.workloadIdentityUser role to the KSA. This permission allows the KSA to impersonate the GSA.
    • The KSA is annotated to link it to the GSA. This annotation tells GKE which GSA the KSA should impersonate by using Workload Identity.

      Any Pod running in the default namespace of your GKE cluster that uses the jaxserviceaccout service account can now authenticate as the GSA and read the objects stored in the gs://${GCS_BUCKET_NAME} bucket. The fine-tuning Job needs this access to read the dataset through Cloud Storage FUSE.

Create the fine-tuning script

In this section, you explore the training script that performs a fine-tuning operation on a Gemma 3 model. This script uses the Gemma3Tokenizer.

Review the following Gemma3LLMTrain.py fine-tuning script:

# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import grain.python as pygrain
import jax
import jax.numpy as jnp
import optax
import pandas as pd
import time
import argparse

from dataclasses import dataclass
from functools import partial
from gemma import gm
from flax.training import train_state
from jax.sharding import Mesh, PartitionSpec, NamedSharding

jax.distributed.initialize()
print("Global device count:", jax.device_count())
print("jax version:", jax.__version__)

tokenizer = gm.text.Gemma3Tokenizer()
num_epochs = 1
learning_rate = 2e-5

@dataclass
class TextDataset:
    data: list
    maxlen: int
    tokenizer: gm.text.Gemma3Tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx: int):
        encoding = self.tokenizer.encode(self.data[idx])[:self.maxlen]  # Tokenize and truncate
        return encoding + [0] * (self.maxlen - len(encoding))  # Pad to maxlen

def load_and_preprocess_data(file_path, batch_size, maxlen, datacount, tokenizer):

    with open(file_path, 'r') as f:
      text = f.read()

    stories = text.split('<|endoftext|>')
    stories = [story for story in stories if story.strip()][:datacount]
    df = pd.DataFrame({'text': stories})
    data = df['text'].dropna().tolist()
    dataset = TextDataset(data, maxlen, tokenizer)

    sampler = pygrain.IndexSampler(
        len(dataset),
        shuffle=False,
        seed=42,
        shard_options=pygrain.NoSharding(),
        num_epochs=num_epochs,
    )

    dataloader = pygrain.DataLoader(
        data_source=dataset,
        sampler=sampler,
        operations=[pygrain.Batch(batch_size=batch_size, drop_remainder=True)],
    )

    return dataloader

def generate_text(model, params, tokenizer, prompt):
    sampler = gm.text.Sampler(
        model=model,
        params=params,
        tokenizer=tokenizer,
    )
    print("Generating response for: " + prompt)
    out = sampler.sample(prompt, max_new_tokens=32)
    print("Response: \n" + out + "\n")
    return out

prep_target_batch = jax.vmap(lambda tokens: jnp.concatenate((tokens[1:], jnp.array([0]))))

@partial(jax.jit, donate_argnums=(0,))
def train_step(state, batch):
    """Performs one supervised fine-tuning step."""

    def loss_fn(params):
        # Run the forward pass. The model returns logits.
        logits = state.apply_fn({'params': params}, batch[0]).logits

        # Calculate the cross-entropy loss.
        loss = optax.softmax_cross_entropy_with_integer_labels(
            logits=logits, labels=batch[1]
        ).mean()

        return loss

    # Compute gradients
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)

    # Update the model state
    state = state.apply_gradients(grads=grads)

    metrics = {'loss': loss}
    return state, metrics

def train_model(state, text_dl, num_epochs, sharding):
    batchCount = 0
    start_time = time.time()
    for epoch in range(num_epochs):
        for batch in text_dl:
            if len(batch) % len(jax.devices()) != 0:
              continue  # skip the remaining elements
            input_batch = jnp.array(batch).T
            target_batch = prep_target_batch(input_batch)
            state, metrics = train_step(state, jax.device_put((input_batch, target_batch), sharding))

            if batchCount % 10 == 0:
                print(f"Loss after batch {batchCount}: {metrics['loss']}")
            batchCount += 1

    end_time = time.time()
    print(f"Completed training model. Total time for training {end_time - start_time} seconds \n")
    return state

def run_training(maxlen, batch_size, datacount):
    print(f"Batch size: {batch_size}, Max length: {maxlen}, Data count: {datacount}")
    #Load the training data
    tiny_stories_dl = load_and_preprocess_data('/data/TinyStories-train.txt', batch_size, maxlen, datacount, tokenizer)
    # Get the Gemma3 model
    model = gm.nn.Gemma3_270M()
    # Load the pretrained parameters
    params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_270M_PT)
    # Create an optimizer
    optimizer = optax.adamw(learning_rate=learning_rate)
    # Define sharding for data parallel training
    mesh = Mesh(jax.devices(), ('batch',))
    sharding = NamedSharding(mesh, PartitionSpec('batch', None))

    # Testing out current state of the model
    test_prompt = "Once upon a time, there was a girl named Amy."
    generate_text(model, params, tokenizer, test_prompt)

    state = train_state.TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=optimizer
    )

    # Perform post training
    print("Start training model")
    state = train_model(state, tiny_stories_dl, num_epochs, sharding)

    # Final text generation
    generate_text(model, state.params, tokenizer, test_prompt)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train Gemma model with custom parameters.')
    parser.add_argument('--maxlen', type=int, default=256, help='Maximum sequence length')
    parser.add_argument('--batch_size', type=int, default=128, help='Batch size')
    parser.add_argument('--datacount', type=int, default=296000, help='Number of data samples to use')
    args = parser.parse_args()

    run_training(maxlen=args.maxlen, batch_size=args.batch_size, datacount=args.datacount)

In this script, the following applies:

  • A Gemma3Tokenizer converts text data into tokens that the model can process.
  • The load_and_preprocess_data function reads the training data from a file, splits it into individual stories, and uses the tokenizer to convert the text into padded sequences of tokens.
  • The generate_text function takes the model, its parameters, and a prompt to generate text.
  • The train_step function defines a single iteration of training that includes the forward pass, loss calculation (using cross-entropy), gradient computation, and parameter updates.
  • The train_model function iterates through the dataset for a specified number of epochs, which calls the train_step function for each batch.
  • The run_training function orchestrates the entire process to load data, initialize the Gemma 3 model (Gemma3_270M) and optimizer, load pre-trained parameters, set up data sharding for parallel processing, run a test generation, execute the training loop, and perform a final text generation to demonstrate the effect of fine-tuning.
  • The script uses the argparse library to accept command-line arguments for the maxlen, batch_size, and datacount parameters.
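Three of these pieces can be sketched in plain, dependency-free Python: the truncate-and-pad step from TextDataset.__getitem__, the next-token label shift from prep_target_batch, and the per-position cross-entropy that the loss uses. The token IDs below are invented, and the real script operates on JAX arrays rather than lists:

```python
import math

MAXLEN = 8

def pad_or_truncate(token_ids, maxlen):
    # Truncate to maxlen, then right-pad with 0s (the pad token ID).
    token_ids = token_ids[:maxlen]
    return token_ids + [0] * (maxlen - len(token_ids))

def shift_targets(token_ids):
    # Next-token labels: drop the first token and append a 0,
    # so position i is trained to predict token i + 1.
    return token_ids[1:] + [0]

def cross_entropy(logits, label):
    # Softmax cross-entropy for a single position, computed stably:
    # log(sum(exp(logits))) - logits[label].
    m = max(logits)
    log_z = m + math.log(sum(math.exp(l - m) for l in logits))
    return log_z - logits[label]

inputs = pad_or_truncate([11, 42, 7], MAXLEN)
labels = shift_targets(inputs)
print(inputs)  # [11, 42, 7, 0, 0, 0, 0, 0]
print(labels)  # [42, 7, 0, 0, 0, 0, 0, 0]
print(cross_entropy([1.0, 2.0, 0.5], 1))
```

In the real script, the shift runs across a whole batch through jax.vmap, and optax.softmax_cross_entropy_with_integer_labels averages this loss over every position.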

Now that you have explored the fine-tuning script, containerize it to run on GKE.

Containerize the fine-tuning script

Before you run the fine-tuning script in a GKE cluster, you need to containerize it. This tutorial uses a JAX AI image as the base image.

  1. Open the Dockerfile in the same directory as the Gemma3LLMTrain.py file:

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    FROM us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:jax0.7.2-rev1
    RUN apt-get update && apt-get install -y wget && rm -rf /var/lib/apt/lists/*
    
    RUN pip install --upgrade pip
    RUN pip install gemma grain
    
    WORKDIR /app
    
    # Copy your training script into the container
    COPY Gemma3LLMTrain.py .
    

    This Dockerfile installs the necessary dependencies and copies the Gemma3LLMTrain.py file into the container.

  2. Build the Docker image and push it to an image repository:

    export REPOSITORY=REPOSITORY_NAME
    export IMAGE_NAME="jax-gemma3-training"
    export IMAGE_TAG="latest"
    export DOCKERFILE_PATH="./Dockerfile"
    export IMAGE_URI="${REGION}-docker.pkg.dev/${PROJECT_ID}/${REPOSITORY}/${IMAGE_NAME}:${IMAGE_TAG}"
    
    docker build -t "${IMAGE_URI}" -f "${DOCKERFILE_PATH}" .
    gcloud auth configure-docker "${REGION}-docker.pkg.dev" -q
    docker push "${IMAGE_URI}"
    

    Replace REPOSITORY_NAME with the name of your Artifact Registry repository.

  3. Add role bindings to the service account:

    export PROJECT_NUMBER=$(gcloud projects describe $PROJECT_ID --format 'get(projectNumber)')
    gcloud artifacts repositories add-iam-policy-binding ${REPOSITORY} \
        --project=${PROJECT_ID} \
        --location=${REGION} \
        --member="serviceAccount:${PROJECT_NUMBER}-compute@developer.gserviceaccount.com" \
        --role="roles/artifactregistry.reader"
    

With the image in the repository, you can now deploy the fine-tuning Job into a GKE cluster.

Deploy the LLM fine-tuning Job

This section shows you how to deploy the LLM fine-tuning Job to your GKE cluster.

  1. Open the training_singlehost.yaml manifest:

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    apiVersion: batch/v1
    kind: Job
    metadata:
      name: jax-gemma3-train-singlehost
    spec:
      template:
        metadata:
          annotations:
            gke-gcsfuse/volumes: "true"
        spec:
          serviceAccountName: jaxserviceaccout
          containers:
          - name: training-container
            image: ${IMAGE_URI}
            imagePullPolicy: "Always"
            command: ["python", "Gemma3LLMTrain.py", "--maxlen", "256", "--batch_size", "64", "--datacount", "355120"]
            resources:
              limits:
                google.com/tpu: 1
            volumeMounts:
            - name: gcs-fuse-csi-ephemeral
              mountPath: /data
          nodeSelector:
            cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
            cloud.google.com/gke-tpu-topology: 1x1
          restartPolicy: Never
          volumes:
          - name: gcs-fuse-csi-ephemeral
            csi:
              driver: gcsfuse.csi.storage.gke.io
              volumeAttributes:
                bucketName: ${GCS_BUCKET_NAME}
                mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
      backoffLimit: 1
  2. Apply the manifest:

    envsubst < training_singlehost.yaml | kubectl apply -f -
    

GKE creates a Job that launches a Pod on a TPU Trillium (v6e) node. This Pod runs the Python fine-tuning script, which accesses the fine-tuning data from the specified Cloud Storage bucket, mounted at the /data path by using Cloud Storage FUSE. The script then fine-tunes the Gemma model.

Monitor the training Job

In this section, you monitor the progress of the fine-tuning Job and its performance.

See fine-tuning progress

  1. List the Pods:

    # Find the Pods
    kubectl get pods
    
  2. Follow the log output:

    kubectl logs -f pods/POD_NAME
    

    Replace POD_NAME with the name of your Pod.

    The output is similar to the following:

    Global device count: 1
    Batch size: 128, Max length: 256, Data count: 96000
    I1028 00:12:55.925999 1387 google_auth_provider.cc:181] Running on GCE, using service account ...
    Generating response for: Once upon a time, there was a girl named Amy.
    Response:
    Amy lived in a small house. The house was in a big field. Amy liked to play in the big field. She
    Start training model
    Loss after batch 0: 10.25
    Loss after batch 10: 4.3125
    .
    .
    .
    Loss after batch 740: 1.41406
    Completed training model. Total time for training 294.6791355609894 seconds
    Generating response for: Once upon a time, there was a girl named Amy.
    Response:
    She loved to play with her toys. One day, Amy's mom told her that she had to go to the store to
    
  3. Analyze the output:

    • The Global device count: 1 line indicates the number of TPU chips that are available to JAX.
    • The model generates reasonable text before this fine-tuning run because it loads from a pre-trained checkpoint.
    • The output generated after fine-tuning shows more resemblance to the start of a short story, indicating the model is learning from the new dataset.
    • Fine-tuning on the full dataset should produce even more refined outputs.

Observe metrics

To see the performance of the fine-tuning Job, check the TPU and CPU metrics. To view observability metrics for your cluster, perform the steps in View cluster and workload observability metrics.

Alternative fine-tuning configurations

This section outlines alternative configurations for your fine-tuning workload.

Model selection

This tutorial used the Gemma3_270M model, which is a small model that fits into a single-host TPU Trillium (v6e) node pool. For larger models that require more memory and compute for fine-tuning, you can use multi-host or multislice node pool configurations.

For a complete list of available models, see the Gemma documentation.

Node pool configurations

This tutorial used a single-host node pool. You can also create multi-host TPU slice node pools or multislice node pools, depending on your needs.

The following tabs show how to create multi-host and multislice node pools:

Multi-host

  1. In Cloud Shell, run the following command:

    gcloud container node-pools create jax-tpu-multihost1 \
        --cluster=${CLUSTER_NAME} \
        --machine-type=ct6e-standard-4t \
        --num-nodes=2 \
        --tpu-topology=2x4 \
        --location=${REGION} \
        --node-locations=${ZONE}
    

    GKE creates a TPU Trillium node pool with a 2x4 topology and two nodes.

  2. Open the training_multihost_jobset.yaml Job definition:

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: jax-gemma3-train-multihost
    spec:
      replicatedJobs:
        - name: trainers
          replicas: 1
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 1
              template:
                metadata:
                  annotations:
                    gke-gcsfuse/volumes: "true"
                spec:
                  serviceAccountName: jaxserviceaccout
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                    cloud.google.com/gke-tpu-topology: 2x4
                    cloud.google.com/gke-nodepool: jax-tpu-multihost1
                  containers:
                  - name: training-container
                    image: ${IMAGE_URI} 
                    imagePullPolicy: "Always"
                    ports:
                      - containerPort: 8471
                    command: ["python", "Gemma3LLMTrain.py", "--maxlen", "256", "--batch_size", "64", "--datacount", "5120"]
                    resources:
                      limits:
                        google.com/tpu: 4
                    volumeMounts:
                    - name: gcs-fuse-csi-ephemeral
                      mountPath: /data
                  volumes:
                    - name: gcs-fuse-csi-ephemeral
                      csi:
                        driver: gcsfuse.csi.storage.gke.io
                        volumeAttributes:
                          bucketName: ${GCS_BUCKET_NAME}
                          mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
    
  3. Deploy the fine-tuning Job:

    envsubst < training_multihost_jobset.yaml | kubectl apply -f -
    

Multislice

  1. In Cloud Shell, run the following command:

    gcloud container node-pools create jax-tpu-multihost1 \
      --cluster=${CLUSTER_NAME} \
      --machine-type=ct6e-standard-4t \
      --num-nodes=2 \
      --tpu-topology=2x4 \
      --location=${REGION} \
      --node-locations=${ZONE}
    
    gcloud container node-pools create jax-tpu-multihost2 \
      --cluster=${CLUSTER_NAME} \
      --machine-type=ct6e-standard-4t \
      --num-nodes=2 \
      --tpu-topology=2x4 \
      --location=${REGION} \
      --node-locations=${ZONE}
    

    GKE creates two TPU Trillium node pools. Each node pool has a 2x4 topology and two nodes.

  2. Open the training_multislice_jobset.yaml Job definition:

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: jax-gemma3-train-multislice
    spec:
      replicatedJobs:
        - name: trainers
          replicas: 2
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 1
              template:
                metadata:
                  annotations:
                    gke-gcsfuse/volumes: "true"
                spec:
                  serviceAccountName: jaxserviceaccout
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                    cloud.google.com/gke-tpu-topology: 2x4
                  containers:
                  - name: training-container
                    image: ${IMAGE_URI}
                    imagePullPolicy: "Always"
                    ports:
                      - containerPort: 8471
                    command: ["python", "Gemma3LLMTrain.py", "--maxlen", "256", "--batch_size", "64", "--datacount", "5120"]
                    resources:
                      limits:
                        google.com/tpu: 4
                    volumeMounts:
                    - name: gcs-fuse-csi-ephemeral
                      mountPath: /data
                  volumes:
                    - name: gcs-fuse-csi-ephemeral
                      csi:
                        driver: gcsfuse.csi.storage.gke.io
                        volumeAttributes:
                          bucketName: ${GCS_BUCKET_NAME}
                          mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:100"
    
  3. Deploy the fine-tuning Job:

    envsubst < training_multislice_jobset.yaml | kubectl apply -f -
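In the multislice JobSet, `replicas` counts slices and `parallelism` counts hosts per slice. The following is a rough sketch of the resulting Pod and chip totals, again assuming 4 chips per ct6e-standard-4t host:

```python
# Sketch: totals implied by the multislice JobSet in this tab.
# replicas = number of TPU slices, parallelism = hosts per slice,
# chips_per_host = 4 for ct6e-standard-4t (an assumption to verify).
def jobset_totals(replicas: int, parallelism: int, chips_per_host: int = 4):
    pods = replicas * parallelism       # one worker Pod per TPU host
    chips = pods * chips_per_host
    return pods, chips

print(jobset_totals(replicas=2, parallelism=2))  # (4, 16)
```

With `replicas: 2` and `parallelism: 2`, the JobSet schedules four worker Pods across 16 TPU chips in the two node pools.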
    

Performance analysis and optimization

To analyze and optimize the performance of your machine learning fine-tuning, you can use XProf. XProf is a suite of tools that profiles and inspects ML workloads built with JAX, TensorFlow, or PyTorch/XLA. By showing execution traces, memory usage, and other data, XProf lets you fine-tune your models and training setup for better efficiency and faster training.

To analyze the performance of your fine-tuning workload by using XProf, complete the following steps:

  • Install the xprof package.
  • Modify your training script to start the XProf server.
  • Modify your Kubernetes Job manifest to include a volume mount for XProf logs.
  • Grant the service account permissions to write XProf logs to a Cloud Storage bucket.
  • Run XProf within your Pod and set up port forwarding to access the XProf dashboard.

Install the XProf package

  1. Navigate to the directory that contains the XProf samples:

      cd ai-ml/llm-training-jax-tpu-gemma3/xprof-enabled
    
  2. Build the Docker image and push it to an image repository:

    export REPOSITORY=REPOSITORY_NAME
    export IMAGE_NAME="jax-gemma3-training-xp"
    export IMAGE_TAG="latest"
    export DOCKERFILE_PATH="./Dockerfile"
    export IMAGE_URI="${REGION}-docker.pkg.dev/${PROJECT_ID}/${REPOSITORY}/${IMAGE_NAME}:${IMAGE_TAG}"
    
    docker build -t "${IMAGE_URI}" -f "${DOCKERFILE_PATH}" .
    gcloud auth configure-docker "${REGION}-docker.pkg.dev" -q
    docker push "${IMAGE_URI}"
    

    Replace REPOSITORY_NAME with the name of your Artifact Registry repository.

  3. Review the Dockerfile that you used to build the image:

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    FROM us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:jax0.7.2-rev1
    RUN apt-get update && apt-get install -y wget && rm -rf /var/lib/apt/lists/*
    
    RUN pip install --upgrade pip
    RUN pip install gemma grain equinox
    RUN pip install xprof
    
    WORKDIR /app
    
    # Copy your training script into the container
    COPY Gemma3LLMTrain.py .
    

    This Dockerfile installs XProf dependencies.
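The training script must also emit profile data for XProf to display. The following is a minimal sketch of the pattern; the port number and log directory are illustrative assumptions, so check the xprof-enabled Gemma3LLMTrain.py for the values it actually uses:

```python
# Sketch: expose a profiling endpoint and write a trace that XProf can read.
# Port 9002 and the log directory are assumptions for illustration; in the
# tutorial, traces would go under the /xprof Cloud Storage mount instead.
import tempfile

import jax
import jax.numpy as jnp

jax.profiler.start_server(9002)      # enables on-demand profile capture

log_dir = tempfile.mkdtemp()         # stands in for /xprof in this sketch
with jax.profiler.trace(log_dir):    # records a fixed window of work
    jnp.ones((128, 128)).sum().block_until_ready()
```

The `jax.profiler.trace` context manager writes trace files to the log directory, which the XProf server later reads through its `--logdir` flag.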

Add a volume mount for XProf logs

In this section, you create and apply a Kubernetes Job manifest that includes the necessary volume mounts for XProf logs.

  1. Open the training_singlehost.yaml Job definition:

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    apiVersion: batch/v1
    kind: Job
    metadata:
      name: jax-gemma3-train-singlehost
    spec:
      template:
        metadata:
          annotations:
            gke-gcsfuse/volumes: "true"
        spec:
          serviceAccountName: jaxserviceaccout
          containers:
          - name: training-container
            image: ${IMAGE_URI}
            imagePullPolicy: "Always"
            command: ["python", "Gemma3LLMTrain.py", "--maxlen", "256", "--batch_size", "64", "--datacount", "851200"]
            resources:
              limits:
                google.com/tpu: 1
            volumeMounts:
            - name: gcs-fuse-csi-ephemeral
              mountPath: /data
            - name: gcs-fuse-csi-ephemeral2
              mountPath: /xprof
          nodeSelector:
            cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
            cloud.google.com/gke-tpu-topology: 1x1
          restartPolicy: Never
          volumes:
          - name: gcs-fuse-csi-ephemeral
            csi:
              driver: gcsfuse.csi.storage.gke.io
              volumeAttributes:
                bucketName: ${GCS_BUCKET_NAME}
                mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
          - name: gcs-fuse-csi-ephemeral2
            csi:
              driver: gcsfuse.csi.storage.gke.io
              volumeAttributes:
                bucketName: ${XPROF_GCS_BUCKET_NAME}
                mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
      backoffLimit: 1
  2. Apply the manifest:

    envsubst < training_singlehost.yaml | kubectl apply -f -
    

Grant the service account permissions to write XProf logs

  1. To grant the service account read and write access to the XProf logs bucket, add the "roles/storage.objectUser" role:

    export GSA_NAME="GSA_NAME" # Same as used in initial setup
    
    # Automatically get the current project ID
    export PROJECT_ID=$(gcloud config get-value project)
    
    # Cloud Storage Bucket details
    export XPROF_GCS_BUCKET_NAME="XPROF_GCS_BUCKET_NAME"
    
    # Derived Variables
    export GSA_EMAIL="${GSA_NAME}@${PROJECT_ID}.iam.gserviceaccount.com"
    
    gcloud storage buckets add-iam-policy-binding "gs://${XPROF_GCS_BUCKET_NAME}" \
        --member="serviceAccount:${GSA_EMAIL}" \
        --role="roles/storage.objectUser" \
        --project="${PROJECT_ID}"
    

    Replace the following:

    • GSA_NAME: the name of the Google Service Account to which to grant the role.
    • XPROF_GCS_BUCKET_NAME: the name of the bucket to which to grant the role.
  2. Run XProf within your Pod:

    kubectl exec POD_NAME -c training-container -it -- bash # exec into the container
    xprof --port 9001 --logdir /xprof # start xprof
    

    Replace POD_NAME with the name of your Pod.

Access the XProf dashboard

  1. Set up port forwarding to the XProf server in the Pod:

    kubectl port-forward POD_NAME 9001:9001
    
  2. In your browser's address bar, enter the following:

    http://localhost:9001/
    

    The XProf Trace Viewer opens.

  3. In the TensorBoard window, click Capture profile.

  4. In the Profile Service URL(s) or TPU name field, enter localhost:9002.

  5. To capture more details, in the Host Trace (TraceMe) Level, select verbose and enable Python trace logging.

  6. To view the dashboard, click Capture.

    TensorBoard captures the profile and lets you analyze the performance of the training script. The graph shows the execution timeline for both TPU and CPU performance profiles:

An example of the XProf trace viewer that shows a performance matrix graph

For more profiling options to analyze your training workload performance, see the JAX documentation on Profiling computation.

Fine-tuning in production environments

This tutorial showed you how to test JAX-based training in a distributed environment. For optimized LLM fine-tuning in production, use the MaxText library. If you are interested in diffusion models, use MaxDiffusion implementations.

For long-running training or fine-tuning workloads in production, set up workload checkpointing to minimize progress loss during a failure. To learn more about setting up multi-tier checkpointing, see Train large-scale machine learning models on GKE with Multi-Tier Checkpointing.

Clean up

To avoid incurring charges to your Google Cloud account for the resources used in this tutorial, either delete the project that contains the resources, or keep the project and delete the individual resources.

Delete the individual resources

If you want to keep the project, delete the individual resources by running the following commands:

  1. Delete the resources you created in this tutorial:

    gcloud container clusters delete ${CLUSTER_NAME} --location=${REGION}
    gcloud storage rm --recursive gs://${GCS_BUCKET_NAME}
    gcloud artifacts docker images delete ${IMAGE_URI} --delete-tags
    
  2. If you don't need the data generated by XProf, remove the Cloud Storage bucket used by XProf:

    gcloud storage rm --recursive gs://${XPROF_GCS_BUCKET_NAME}
    

What's next