Menyesuaikan LLM menggunakan TPU di GKE dengan JAX

Tutorial ini menunjukkan cara menyesuaikan model bahasa besar (LLM) menggunakan Unit Pemrosesan Tensor (TPU) di Google Kubernetes Engine (GKE) dengan JAX. Dengan penyesuaian, Anda dapat mengadaptasi model dasar seperti Gemma 3 ke domain atau tugas tertentu. Proses ini meningkatkan presisi dan akurasi model dengan memperbarui parameternya menggunakan set data khusus Anda sendiri.

Panduan ini adalah titik awal yang baik jika Anda memerlukan kontrol terperinci, penyesuaian, skalabilitas, ketahanan, portabilitas, dan efektivitas biaya Kubernetes terkelola saat melakukan penyesuaian beban kerja AI/ML.

Latar belakang

Dengan menggunakan TPU di GKE dengan Jax untuk menyempurnakan LLM, Anda dapat membangun solusi penyempurnaan yang tangguh dan siap produksi dengan semua manfaat Kubernetes terkelola.

Gemma

Gemma adalah sekumpulan model multimodal AI/ML generatif yang ringan dan tersedia secara terbuka, yang dirilis dengan lisensi terbuka. Model AI ini tersedia untuk dijalankan di aplikasi, hardware, perangkat seluler, atau layanan yang dihosting. Gemma 3 memperkenalkan multimodality, dan mendukung input vision-language dan output teks. Model ini menangani jendela konteks hingga 128.000 token dan mendukung lebih dari 140 bahasa. Gemma 3 juga menawarkan peningkatan kemampuan matematika, penalaran, dan chat, termasuk output terstruktur dan panggilan fungsi.

Anda dapat menggunakan model Gemma untuk pembuatan teks, atau Anda juga dapat menyesuaikan model ini untuk tugas khusus.

Untuk mengetahui informasi selengkapnya, lihat dokumentasi Gemma.

TPU

TPU adalah sirkuit terintegrasi khusus aplikasi (ASIC) yang dikembangkan khusus oleh Google untuk mempercepat model machine learning dan AI yang dibangun menggunakan framework seperti TensorFlow, PyTorch, dan JAX.

Sebelum menggunakan TPU di GKE, sebaiknya selesaikan jalur pembelajaran berikut:

  1. Pelajari ketersediaan versi TPU saat ini dengan arsitektur sistem Cloud TPU.
  2. Pelajari TPU di GKE.

JAX

JAX adalah framework machine learning berperforma tinggi yang dirancang untuk digunakan dengan TPU dan GPU. JAX menyediakan API untuk membangun dan melatih model machine learning.

Untuk mempelajari lebih lanjut, lihat repositori JAX.

Tujuan

Tutorial ini membahas langkah-langkah berikut:

  1. Buat cluster GKE Autopilot atau Standard dengan topologi TPU yang direkomendasikan, berdasarkan karakteristik model. Selama tutorial ini, Anda akan melakukan penyesuaian pada node pool host tunggal.
  2. Tambahkan data ke bucket Cloud Storage dan pasang ke container melalui Cloud Storage FUSE.
  3. Deploy Tugas penyesuaian LLM di GKE.
  4. Pantau Tugas penyesuaian dan lihat log.

Sebelum memulai

  • Login ke akun Google Cloud Anda. Jika Anda baru menggunakan Google Cloud, buat akun untuk mengevaluasi performa produk kami dalam skenario dunia nyata. Pelanggan baru juga mendapatkan kredit gratis senilai $300 untuk menjalankan, menguji, dan men-deploy workload.
  • In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Roles required to select or create a project

    • Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
    • Create a project: To create a project, you need the Project Creator role (roles/resourcemanager.projectCreator), which contains the resourcemanager.projects.create permission. Learn how to grant roles.

    Go to project selector

  • Verify that billing is enabled for your Google Cloud project.

  • Enable the required API.

    Roles required to enable APIs

    To enable APIs, you need the Service Usage Admin IAM role (roles/serviceusage.serviceUsageAdmin), which contains the serviceusage.services.enable permission. Learn how to grant roles.

    Enable the API

  • In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Roles required to select or create a project

    • Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
    • Create a project: To create a project, you need the Project Creator role (roles/resourcemanager.projectCreator), which contains the resourcemanager.projects.create permission. Learn how to grant roles.

    Go to project selector

  • Verify that billing is enabled for your Google Cloud project.

  • Enable the required API.

    Roles required to enable APIs

    To enable APIs, you need the Service Usage Admin IAM role (roles/serviceusage.serviceUsageAdmin), which contains the serviceusage.services.enable permission. Learn how to grant roles.

    Enable the API

  • Pastikan Anda memiliki peran berikut di project: roles/container.admin,roles/iam.serviceAccountAdmin,roles/storage.admin

    Memeriksa peran

    1. Di konsol Google Cloud , buka halaman IAM.

      Buka IAM
    2. Pilih project.
    3. Di kolom Akun utama, temukan semua baris yang mengidentifikasi Anda atau grup yang Anda ikuti. Untuk mengetahui grup mana saja yang Anda ikuti, hubungi administrator Anda.

    4. Untuk semua baris yang menentukan atau menyertakan Anda, periksa kolom Peran untuk melihat apakah daftar peran menyertakan peran yang diperlukan.

    Memberikan peran

    1. Di konsol Google Cloud , buka halaman IAM.

      Buka IAM
    2. Pilih project.
    3. Klik Grant access.
    4. Di kolom New principals, masukkan ID pengguna Anda. Biasanya, ini adalah alamat email untuk Akun Google.

    5. Klik Pilih peran, lalu telusuri peran.
    6. Untuk memberikan peran tambahan, klik Add another role, lalu tambahkan tiap peran tambahan.
    7. Klik Simpan.
  • Pastikan Anda memiliki kuota yang cukup untuk chip TPU Trillium (v6e) sebanyak 16 unit. Dalam tutorial ini, Anda menggunakan konfigurasi kumpulan node yang memerlukan 16 chip dan instance on-demand.
  • Pastikan Anda memiliki repositori Docker. Jika Anda belum memilikinya, buat repositori standar di Artifact Registry.

Menyiapkan lingkungan

Dalam tutorial ini, Anda akan menggunakan Cloud Shell untuk mengelola resource yang dihosting di Google Cloud. Cloud Shell telah diinstal sebelumnya dengan software yang Anda perlukan untuk tutorial ini, termasuk kubectl dan Google Cloud CLI.

Untuk menyiapkan lingkungan Anda dengan Cloud Shell, ikuti langkah-langkah berikut:

  1. Di konsol Google Cloud , luncurkan sesi Cloud Shell dan klik Ikon aktivasi Cloud Shell Activate Cloud Shell. Tindakan ini akan meluncurkan sesi di panel bawah konsol Google Cloud .

  2. Tetapkan variabel lingkungan default:

    gcloud config set project PROJECT_ID
    gcloud config set billing/quota_project PROJECT_ID
    export PROJECT_ID=$(gcloud config get project)
    export CLUSTER_NAME=CLUSTER_NAME
    export REGION=CONTROL_PLANE_LOCATION
    export ZONE=ZONE
    export GCS_BUCKET_NAME=BUCKET_NAME
    

    Ganti nilai berikut:

    • PROJECT_ID: Google Cloud Project ID Anda.
    • CLUSTER_NAME: nama cluster GKE Anda.
    • CONTROL_PLANE_LOCATION: region Compute Engine tempat cluster GKE dan node TPU Anda berada. Region harus berisi zona tempat jenis mesin TPU Trillium (v6e) tersedia.
    • ZONE: zona dalam region CONTROL_PLANE_LOCATION yang Anda pilih tempat jenis mesin TPU Trillium (v6e) tersedia. Untuk mencantumkan zona tempat TPU Trillium (v6e) tersedia, jalankan perintah berikut:

        gcloud compute accelerator-types list --filter="name~ct6e" --format="value(zone)"
      
    • BUCKET_NAME: nama bucket Cloud Storage yang berisi data pelatihan Anda.

  3. Clone repositori contoh:

    git clone https://github.com/GoogleCloudPlatform/kubernetes-engine-samples.git
    cd kubernetes-engine-samples
    
  4. Buka direktori kerja:

    cd ai-ml/llm-training-jax-tpu-gemma3
    

Membuat dan mengonfigurasi resource Google Cloud

Di bagian ini, Anda akan membuat dan mengonfigurasi Google Cloud resource.

Membuat cluster GKE

Anda dapat menyetel LLM di TPU dalam cluster GKE Autopilot atau Standard. Sebaiknya gunakan cluster Autopilot untuk mendapatkan pengalaman Kubernetes yang terkelola sepenuhnya. Untuk memilih mode operasi GKE yang paling sesuai untuk workload Anda, lihat Memilih mode operasi GKE.

Autopilot

Buat cluster Autopilot GKE yang menggunakan Workload Identity Federation for GKE dan telah mengaktifkan Cloud Storage FUSE.

gcloud container clusters create-auto ${CLUSTER_NAME} \
    --location=${REGION}

Pembuatan cluster mungkin memerlukan waktu beberapa menit.

Standar

  1. Buat cluster Standar GKE regional yang menggunakan Workload Identity Federation for GKE dan telah mengaktifkan Cloud Storage FUSE.

    gcloud container clusters create ${CLUSTER_NAME} \
        --enable-ip-alias \
        --addons GcsFuseCsiDriver \
        --machine-type=n2-standard-4 \
        --num-nodes=2 \
        --workload-pool=${PROJECT_ID}.svc.id.goog \
        --location=${REGION}
    

    Pembuatan cluster mungkin memerlukan waktu beberapa menit.

  2. Buat node pool host tunggal:

    gcloud container node-pools create jax-tpu-nodepool \
        --cluster=${CLUSTER_NAME} \
        --machine-type=ct6e-standard-1t \
        --num-nodes=1 \
        --location=${REGION} \
        --node-locations=${ZONE} \
        --workload-metadata=GKE_METADATA
    

GKE membuat node pool TPU Trillium dengan topologi 1x1 dan satu node. Flag --workload-metadata=GKE_METADATA mengonfigurasi node pool untuk menggunakan server metadata GKE.

Instal JobSet

  1. Konfigurasi kubectl untuk berkomunikasi dengan cluster Anda:

    gcloud container clusters get-credentials ${CLUSTER_NAME} --location=${REGION}
    
  2. Instal JobSet versi terbaru yang dirilis:

    kubectl apply --server-side -f https://github.com/kubernetes-sigs/jobset/releases/download/JOBSET_VERSION/manifests.yaml
    

    Ganti JOBSET_VERSION dengan versi JobSet yang dirilis terbaru. Misalnya, v0.11.0.

  3. Verifikasi penginstalan JobSet:

    kubectl get pods -n jobset-system
    

    Outputnya mirip dengan hal berikut ini:

    NAME                                         READY   STATUS    RESTARTS   AGE
    jobset-controller-manager-6c56668494-l4dhc   1/1     Running   0          4m45s
    

    Anda mungkin perlu menambahkan lebih banyak node jika JobSet menunggu resource.

Mengonfigurasi Cloud Storage FUSE

Untuk menyesuaikan LLM, Anda perlu memberikan data pelatihan. Dalam tutorial ini, Anda akan menggunakan set data TinyStories dari Hugging Face. Set data ini berisi cerita pendek yang dibuat secara sintetis oleh GPT-3.5 dan GPT-4, yang menggunakan kosakata terbatas.

Bagian ini membahas langkah-langkah untuk mengonfigurasi Cloud Storage FUSE agar dapat membaca data dari bucket Cloud Storage.

  1. Download set data:

    wget https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories-train.txt?download=true -O TinyStories-train.txt
    
  2. Upload data ke bucket Cloud Storage baru:

    gcloud storage buckets create gs://${GCS_BUCKET_NAME} \
        --location=${REGION} \
        --enable-hierarchical-namespace \
        --uniform-bucket-level-access
    gcloud storage cp TinyStories-train.txt gs://${GCS_BUCKET_NAME}
    
  3. Agar workload Anda dapat membaca data melalui Cloud Storage FUSE, buat akun layanan Kubernetes (KSA) dan tambahkan izin yang diperlukan. Jalankan skrip permissionsetup.sh:

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    #!/bin/bash
    
    # --- Configuration Variables ---
    # Kubernetes Service Account details
    export KSA_NAME="jaxserviceaccout"
    export NAMESPACE="default"
    
    # Google Cloud IAM Service Account details
    export GSA_NAME="<GSA_NAME>"
    # Automatically get the current project ID
    export PROJECT_ID=$(gcloud config get-value project)
    export  GSA_DESCRIPTION="GKE Service Account to read GCS bucket for ${KSA_NAME}"
    
    # GCS Bucket details
    export GCS_BUCKET_NAME="<GCS_BUCKET_NAME>" # <--- IMPORTANT: Update this to your bucket name
    
    # Derived Variables
    export GSA_EMAIL="${GSA_NAME}@${PROJECT_ID}.iam.gserviceaccount.com"
    export WI_MEMBER="serviceAccount:${PROJECT_ID}.svc.id.goog[${NAMESPACE}/${KSA_NAME}]"
    
    # --- Check if PROJECT_ID is set ---
    if [ -z "${PROJECT_ID}" ]; then
      echo "Error: PROJECT_ID is not set. Please set it using 'gcloud config set project YOUR_PROJECT_ID'"
      exit 1
    fi
    
    echo "--- Configuration ---"
    echo "KSA_NAME:      ${KSA_NAME}"
    echo "NAMESPACE:     ${NAMESPACE}"
    echo "GSA_NAME:      ${GSA_NAME}"
    echo "PROJECT_ID:    ${PROJECT_ID}"
    echo "GSA_EMAIL:     ${GSA_EMAIL}"
    echo "GCS_BUCKET_NAME:   ${GCS_BUCKET_NAME}"
    echo "WI_MEMBER:     ${WI_MEMBER}"
    echo "--------------------"
    read -p "Press enter to continue..."
    
    # --- Command Execution ---
    
    echo "[1/5] Creating Google Cloud IAM Service Account (GSA): ${GSA_NAME}"
    gcloud iam service-accounts create "${GSA_NAME}" \
        --project="${PROJECT_ID}" \
        --description="${GSA_DESCRIPTION}" \
        --display-name="${GSA_NAME}"
    
    echo "[2/5] Granting GSA '${GSA_EMAIL}' read access (roles/storage.objectViewer) to bucket 'gs://${GCS_BUCKET_NAME}'"
    gcloud storage buckets add-iam-policy-binding "gs://${GCS_BUCKET_NAME}" \
        --member="serviceAccount:${GSA_EMAIL}" \
        --role="roles/storage.objectViewer" \
        --project="${PROJECT_ID}"
    
    echo "[3/5] Creating Kubernetes Service Account (KSA): ${KSA_NAME} in namespace ${NAMESPACE}"
    kubectl create serviceaccount "${KSA_NAME}" --namespace "${NAMESPACE}"
    
    echo "[4/5] Allowing KSA to impersonate GSA (Workload Identity Binding): ${GSA_EMAIL}"
    gcloud iam service-accounts add-iam-policy-binding "${GSA_EMAIL}" \
        --role roles/iam.workloadIdentityUser \
        --member "${WI_MEMBER}" \
        --project="${PROJECT_ID}"
    
    echo "[5/5] Annotating KSA '${KSA_NAME}' to link with GSA '${GSA_EMAIL}'"
    kubectl annotate serviceaccount "${KSA_NAME}" \
        --namespace "${NAMESPACE}" \
        iam.gke.io/gcp-service-account="${GSA_EMAIL}"
    
    echo "--- Setup Complete ---"
    echo "Pods in namespace '${NAMESPACE}' using serviceAccount '${KSA_NAME}' can now authenticate as '${GSA_EMAIL}' and have read access to 'gs://${GCS_BUCKET_NAME}'."
    

    Setelah Anda menjalankan skrip ini, resource berikut akan dikonfigurasi di projectGoogle Cloud dan cluster GKE Anda:

    • Akun layanan IAM baru bernama gcs-fuse-sa dibuat di project Anda.
    • Google Cloud Akun Layanan (GSA) (gcs-fuse-sa) yang dibuat diberi peran roles/storage.objectViewer di bucket Cloud Storage yang ditentukan oleh ${GCS_BUCKET_NAME}. Izin ini memungkinkan GSA membaca objek dari bucket.
    • KSA baru bernama jaxserviceaccount dibuat di namespace default dalam cluster GKE Anda.
    • Kebijakan IAM GSA diperbarui untuk memberikan peran roles/iam.workloadIdentityUser kepada KSA. Izin ini memungkinkan KSA meniru identitas GSA.
    • KSA diberi anotasi untuk menautkannya ke GSA. Anotasi ini memberi tahu GKE GSA mana yang harus di-impersonate oleh KSA menggunakan Workload Identity.

      Setiap Pod yang berjalan di namespace default cluster GKE Anda yang menggunakan akun layanan jaxserviceaccount kini akan dapat melakukan autentikasi sebagai GSA gcs-fuse-sa. Pod ini akan memiliki akses baca ke objek yang disimpan di bucket gs://${GCS_BUCKET_NAME}, yang penting agar Tugas penyesuaian dapat mengakses set data menggunakan Cloud Storage FUSE.

Buat skrip penyesuaian

Di bagian ini, Anda akan mempelajari skrip pelatihan yang melakukan operasi penyesuaian pada model Gemma 3. Skrip ini menggunakan Gemma3Tokenizer.

Tinjau skrip penyesuaian Gemma3LLMTrain.py berikut:

# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import grain.python as pygrain
import jax
import jax.numpy as jnp
import optax
import pandas as pd
import time
import argparse

from dataclasses import dataclass
from functools import partial
from gemma import gm
from flax.training import train_state
from jax.sharding import Mesh, PartitionSpec, NamedSharding

jax.distributed.initialize()
print("Global device count:", jax.device_count())
print("jax version:", jax.__version__)

tokenizer = gm.text.Gemma3Tokenizer()
num_epochs = 1
learning_rate = 2e-5

@dataclass
class TextDataset:
    data: list
    maxlen: int
    tokenizer: gm.text.Gemma3Tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx: int):
        encoding = self.tokenizer.encode(self.data[idx])[:self.maxlen]  # Tokenize and truncate
        return encoding + [0] * (self.maxlen - len(encoding))  # Pad to maxlen

def load_and_preprocess_data(file_path, batch_size, maxlen, datacount, tokenizer):

    with open(file_path, 'r') as f:
      text = f.read()

    stories = text.split('<|endoftext|>')
    stories = [story for story in stories if story.strip()][:datacount]
    df = pd.DataFrame({'text': stories})
    data = df['text'].dropna().tolist()
    dataset = TextDataset(data, maxlen, tokenizer)

    sampler = pygrain.IndexSampler(
        len(dataset),
        shuffle=False,
        seed=42,
        shard_options=pygrain.NoSharding(),
        num_epochs=num_epochs,
    )

    dataloader = pygrain.DataLoader(
        data_source=dataset,
        sampler=sampler,
        operations=[pygrain.Batch(batch_size=batch_size, drop_remainder=True)],
    )

    return dataloader

def generate_text(model, params, tokenizer, prompt):
    sampler = gm.text.Sampler(
        model=model,
        params=params,
        tokenizer=tokenizer,
    )
    print("Generating response for: " + prompt)
    out = sampler.sample(prompt, max_new_tokens=32)
    print("Reponse: \n" + out + "\n")
    return out

prep_target_batch = jax.vmap(lambda tokens: jnp.concatenate((tokens[1:], jnp.array([0]))))

@partial(jax.jit, donate_argnums=(0,))
def train_step(state, batch):
    """Performs one supervised fine-tuning step."""

    def loss_fn(params):
        # Run the forward pass. The model returns logits.
        logits = state.apply_fn({'params': params}, batch[0]).logits

        # Calculate the cross-entropy loss.
        loss = optax.softmax_cross_entropy_with_integer_labels(
            logits=logits, labels=batch[1]
        ).mean()

        return loss

    # Compute gradients
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)

    # Update the model state
    state = state.apply_gradients(grads=grads)

    metrics = {'loss': loss}
    return state, metrics

def train_model(state, text_dl, num_epochs, sharding):
    batchCount = 0
    start_time = time.time()
    for epoch in range(num_epochs):
        start_time = time.time()
        for batch in text_dl:
            if len(batch) % len(jax.devices()) != 0:
              continue  # skip the remaining elements
            input_batch = jnp.array(jnp.array(batch).T)
            target_batch = prep_target_batch(input_batch)
            state, metrics = train_step(state, jax.device_put((input_batch, target_batch), sharding))

            if batchCount % 10 == 0:
                print(f"Loss after batch {batchCount}: {metrics['loss']}")
            batchCount += 1

    end_time = time.time()
    print(f"Completed training model. Total time for training {end_time - start_time} seconds \n")
    return state

def run_training(maxlen, batch_size, datacount):
    print(f"Batch size: {batch_size}, Max length: {maxlen}, Data count: {datacount}")
    #Load the training data
    tiny_stories_dl = load_and_preprocess_data('/data/TinyStories-train.txt', batch_size, maxlen, datacount, tokenizer)
    # Get the Gemma3 model
    model = gm.nn.Gemma3_270M()
    # Load the pretrained parameters
    params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_270M_PT)
    # Create an optimizer
    optimizer = optax.adamw(learning_rate=learning_rate)
    # Define sharding for data parallel training
    mesh = Mesh(jax.devices(), ('batch',))
    sharding = NamedSharding(mesh, PartitionSpec('batch', None))

    # Testing out current state of the model
    test_prompt = "Once upon a time, there was a girl named Amy."
    generate_text(model, params, tokenizer, test_prompt)

    state = train_state.TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=optimizer
    )

    # Perform post training
    print("Start training model")
    state = train_model(state, tiny_stories_dl, num_epochs, sharding)

    # Final text generation
    generate_text(model, state.params, tokenizer, test_prompt)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train Gemma model with custom parameters.')
    parser.add_argument('--maxlen', type=int, default=256, help='Maximum sequence length')
    parser.add_argument('--batch_size', type=int, default=128, help='Batch size')
    parser.add_argument('--datacount', type=int, default=296000, help='Number of data samples to use')
    args = parser.parse_args()

    run_training(maxlen=args.maxlen, batch_size=args.batch_size, datacount=args.datacount)

Dalam skrip ini, hal berikut berlaku:

  • Gemma3Tokenizer mengonversi data teks menjadi token yang dapat diproses oleh model.
  • Fungsi load_and_preprocess_data membaca data pelatihan dari file, membaginya menjadi cerita individual, dan menggunakan tokenizer untuk mengonversi teks menjadi urutan token yang diisi.
  • Fungsi generate_text menggunakan model, parameternya, dan perintah untuk membuat teks.
  • Fungsi train_step menentukan satu iterasi pelatihan yang mencakup penerusan, penghitungan kerugian (menggunakan entropi silang), penghitungan gradien, dan pembaruan parameter.
  • Fungsi train_model melakukan iterasi melalui set data untuk sejumlah epoch yang ditentukan, yang memanggil fungsi train_step untuk setiap batch.
  • Fungsi run_training mengatur seluruh proses untuk memuat data, menginisialisasi model Gemma 3 (Gemma3_270M) dan pengoptimal, memuat parameter yang telah dilatih sebelumnya, menyiapkan sharding data untuk pemrosesan paralel, menjalankan pembuatan pengujian, menjalankan loop pelatihan, dan melakukan pembuatan teks akhir untuk menunjukkan efek penyesuaian.
  • Skrip ini menggunakan library argparse untuk menerima argumen command line untuk parameter maxlen, batch_size, dan datacount.

Setelah mempelajari skrip penyesuaian, buat skrip tersebut dalam container untuk dijalankan di GKE.

Menyimpan skrip penyesuaian ke dalam container

Sebelum menjalankan skrip penyesuaian di cluster GKE, Anda harus membuatnya dalam container. Tutorial ini menggunakan gambar AI JAX sebagai gambar dasar.

  1. Buka Dockerfile di direktori yang sama dengan file Gemma3LLMTrain.py:

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    FROM us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:jax0.7.2-rev1
    RUN apt-get update && apt-get install -y wget && rm -rf /var/lib/apt/lists/*
    
    RUN pip install --upgrade pip
    RUN pip install gemma grain
    
    WORKDIR /app
    
    # Copy your training script into the container
    COPY Gemma3LLMTrain.py .
    

    Dockerfile ini menginstal dependensi yang diperlukan dan menyalin file Gemma3LLMTrain.py ke dalam container.

  2. Bangun image Docker dan kirimkan ke repositori image:

    export REPOSITORY=REPOSITORY_NAME
    export IMAGE_NAME="jax-gemma3-training"
    export IMAGE_TAG="latest"
    export DOCKERFILE_PATH="./Dockerfile"
    export IMAGE_URI="${REGION}-docker.pkg.dev/${PROJECT_ID}/${REPOSITORY}/${IMAGE_NAME}:${IMAGE_TAG}"
    
    docker build -t "${IMAGE_URI}" -f "${DOCKERFILE_PATH}" .
    gcloud auth configure-docker "${REGION}-docker.pkg.dev" -q
    docker push "${IMAGE_URI}"
    

    Ganti REPOSITORY_NAME dengan nama repositori Artifact Registry Anda.

  3. Tambahkan binding peran ke akun layanan:

    export PROJECT_NUMBER=$(gcloud projects describe $PROJECT_ID --format 'get(projectNumber)')
    gcloud artifacts repositories add-iam-policy-binding ${REPOSITORY} \
        --project=${PROJECT_ID} \
        --location=${REGION} \
        --member="serviceAccount:${PROJECT_NUMBER}-compute@developer.gserviceaccount.com" \
        --role="roles/artifactregistry.reader"
    

Dengan gambar di repositori, Anda kini dapat men-deploy Tugas penyesuaian ke cluster GKE.

Men-deploy Tugas fine-tuning LLM

Bagian ini menunjukkan cara men-deploy Tugas penyesuaian LLM ke cluster GKE Anda.

  1. Buka manifes training_singlehost.yaml:

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    apiVersion: batch/v1
    kind: Job
    metadata:
      name: jax-gemma3-train-singlehost
    spec:
      template:
        metadata:
          annotations:
            gke-gcsfuse/volumes: "true"
        spec:
          serviceAccountName: jaxserviceaccout
          containers:
          - name: training-container
            image: ${IMAGE_URI}
            imagePullPolicy: "Always"
            command: ["python", "Gemma3LLMTrain.py", "--maxlen", "256", "--batch_size", "64", "--datacount", "355120"]
            resources:
              limits:
                google.com/tpu: 1
            volumeMounts:
            - name: gcs-fuse-csi-ephemeral
              mountPath: /data
          nodeSelector:
            cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
            cloud.google.com/gke-tpu-topology: 1x1
          restartPolicy: Never
          volumes:
          - name: gcs-fuse-csi-ephemeral
            csi:
              driver: gcsfuse.csi.storage.gke.io
              volumeAttributes:
                bucketName: ${GCS_BUCKET_NAME}
                mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
      backoffLimit: 1
  2. Terapkan manifes:

    envsubst < training_singlehost.yaml | kubectl apply -f -
    

GKE membuat Tugas yang meluncurkan Pod di node TPU Trillium (v6e). Pod ini menjalankan skrip penyesuaian Python, yang mengakses data penyesuaian dari bucket Cloud Storage yang ditentukan yang dipasang di jalur /data menggunakan Cloud Storage FUSE. Kemudian, skrip akan melakukan penyesuaian pada model Gemma.

Memantau Tugas pelatihan

Di bagian ini, Anda akan memantau progres Tugas penyesuaian dan performanya.

Melihat progres penyesuaian

  1. Mencantumkan Pod:

    # Find the Pods
    kubectl get pods
    
  2. Ikuti output log:

    kubectl logs -f pods/POD_NAME
    

    Ganti POD_NAME dengan nama Pod Anda.

    Outputnya mirip dengan hal berikut ini:

    Global device count: 1
    Batch size: 128, Max length: 256, Data count: 96000
    I1028 00:12:55.925999 1387 google_auth_provider.cc:181] Running on GCE, using service account ...
    Generating response for: Once upon a time, there was a girl named Amy.
    Response:
    Amy lived in a small house. The house was in a big field. Amy liked to play in the big field. She
    Start training model
    Loss after batch 0: 10.25
    Loss after batch 10: 4.3125
    .
    .
    .
    Loss after batch 740: 1.41406
    Completed training model. Total time for training 294.6791355609894 seconds
    Generating response for: Once upon a time, there was a girl named Amy.
    Response:
    She loved to play with her toys. One day, Amy's mom told her that she had to go to the store to
    
  3. Analisis output:

    • Baris Global device count: 1 menunjukkan core TPU yang digunakan.
    • Model menghasilkan teks yang wajar sebelum menjalankan fine-tuning ini karena model dimuat dari checkpoint yang telah dilatih sebelumnya.
    • Output yang dihasilkan setelah penyesuaian menunjukkan lebih banyak kemiripan dengan awal cerita pendek, yang menunjukkan bahwa model sedang belajar dari set data baru.
    • Penyesuaian pada set data lengkap akan menghasilkan output yang lebih baik.

Mengamati metrik

Lihat performa Tugas penyesuaian dengan memeriksa metrik TPU dan CPU. Untuk melihat metrik kemampuan observasi untuk cluster Anda, lakukan langkah-langkah di Melihat metrik kemampuan observasi cluster dan workload.

Konfigurasi penyesuaian alternatif

Bagian ini menguraikan konfigurasi alternatif untuk workload penyesuaian Anda.

Pemilihan model

Tutorial ini menggunakan model Gemma3_270M, yang merupakan model kecil yang cocok ke dalam node pool TPU Trillium (v6e) host tunggal. Untuk model yang lebih besar yang memerlukan lebih banyak memori dan komputasi untuk penyesuaian, Anda dapat menggunakan konfigurasi node pool multi-host atau multislice.

Untuk mengetahui daftar lengkap model yang tersedia, lihat dokumentasi Gemma.

Konfigurasi kumpulan node

Tutorial ini menggunakan node pool host tunggal. Anda juga dapat membuat node pool slice TPU multi-host atau node pool multislice, bergantung pada kebutuhan Anda.

Tab berikut menunjukkan cara membuat kumpulan node multi-host dan multiris:

Multi-host

  1. Jalankan perintah berikut di Cloud Shell:

    gcloud container node-pools create jax-tpu-multihost1 \
        --cluster=${CLUSTER_NAME} \
        --machine-type=ct6e-standard-4t \
        --num-nodes=2 \
        --tpu-topology=2x4 \
        --location=${REGION} \
        --node-locations=${ZONE}
    

    GKE membuat node pool TPU Trillium dengan topologi 2x4 dan dua node.

  2. Buka definisi Tugas training_multihost_jobset.yaml:

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: jax-gemma3-train-multihost
    spec:
      replicatedJobs:
        - name: trainers
          replicas: 1
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 1
              template:
                metadata:
                  annotations:
                    gke-gcsfuse/volumes: "true"
                spec:
                  serviceAccountName: jaxserviceaccout
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                    cloud.google.com/gke-tpu-topology: 2x4
                    cloud.google.com/gke-nodepool: jax-tpu-multihost1
                  containers:
                  - name: training-container
                    image: ${IMAGE_URI} 
                    imagePullPolicy: "Always"
                    ports:
                      - containerPort: 8471
                    command: ["python", "Gemma3LLMTrain.py", "--maxlen", "256", "--batch_size", "64", "--datacount", "5120"]
                    resources:
                      limits:
                        google.com/tpu: 4
                    volumeMounts:
                    - name: gcs-fuse-csi-ephemeral
                      mountPath: /data
                  volumes:
                    - name: gcs-fuse-csi-ephemeral
                      csi:
                        driver: gcsfuse.csi.storage.gke.io
                        volumeAttributes:
                          bucketName: ${GCS_BUCKET_NAME}
                          mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
    
  3. Deploy Tugas penyesuaian:

    envsubst < training_multihost_jobset.yaml | kubectl apply -f -
    

Multislice

  1. Jalankan perintah berikut di Cloud Shell:

    gcloud container node-pools create jax-tpu-multihost1 \
      --cluster=${CLUSTER_NAME} \
      --machine-type=ct6e-standard-4t \
      --num-nodes=2 \
      --tpu-topology=2x4 \
      --location=${REGION} \
      --node-locations=${ZONE}
    
    gcloud container node-pools create jax-tpu-multihost2 \
      --cluster=${CLUSTER_NAME} \
      --machine-type=ct6e-standard-4t \
      --num-nodes=2 \
      --tpu-topology=2x4 \
      --location=${REGION} \
      --node-locations=${ZONE}
    

    GKE membuat dua node pool TPU Trillium. Setiap node pool memiliki topologi 2x4 dan dua node.

  2. Buka definisi Tugas training_multislice_jobset.yaml:

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: jax-gemma3-train-multislice
    spec:
      replicatedJobs:
        - name: trainers
          replicas: 2
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 1
              template:
                metadata:
                  annotations:
                    gke-gcsfuse/volumes: "true"
                spec:
                  serviceAccountName: jaxserviceaccout
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                    cloud.google.com/gke-tpu-topology: 2x4
                  containers:
                  - name: training-container
                    image: ${IMAGE_URI}
                    imagePullPolicy: "Always"
                    ports:
                      - containerPort: 8471
                    command: ["python", "Gemma3LLMTrain.py", "--maxlen", "256", "--batch_size", "64", "--datacount", "5120"]
                    resources:
                      limits:
                        google.com/tpu: 4
                    volumeMounts:
                    - name: gcs-fuse-csi-ephemeral
                      mountPath: /data
                  volumes:
                    - name: gcs-fuse-csi-ephemeral
                      csi:
                        driver: gcsfuse.csi.storage.gke.io
                        volumeAttributes:
                          bucketName: ${GCS_BUCKET_NAME}
                          mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:100"
    
  3. Deploy Tugas penyesuaian:

    envsubst < training_multislice_jobset.yaml | kubectl apply -f -
    

Analisis dan pengoptimalan performa

Untuk menganalisis dan mengoptimalkan performa penyesuaian model machine learning, Anda dapat menggunakan XProf. XProf adalah rangkaian alat yang memprofilkan dan memeriksa beban kerja ML yang dibuat dengan JAX, TensorFlow, atau PyTorch/XLA. Dengan menampilkan rekaman eksekusi, penggunaan memori, dan data lainnya, XProf memungkinkan Anda menyesuaikan model dan konfigurasi pelatihan untuk efisiensi yang lebih baik dan pelatihan yang lebih cepat.

Untuk menganalisis performa workload penyesuaian Anda menggunakan XProf, Anda harus menyelesaikan langkah-langkah berikut di bagian ini:

  • Instal paket xprof. Ubah skrip pelatihan untuk memulai server XProf.
  • Ubah manifes Job Kubernetes Anda untuk menyertakan pemasangan volume untuk log XProf.
  • Beri akun layanan izin untuk menulis log XProf ke bucket Cloud Storage.
  • Jalankan XProf dalam Pod Anda dan siapkan penerusan port untuk mengakses dasbor XProf.

Instal paket XProf

  1. Buka direktori yang berisi sampel XProf:

      cd ai-ml/llm-training-jax-tpu-gemma3/xprof-enabled
    
  2. Bangun image Docker dan kirimkan ke repositori image:

    export REPOSITORY=REPOSITORY_NAME
    export IMAGE_NAME="jax-gemma3-training-xp"
    export IMAGE_TAG="latest"
    export DOCKERFILE_PATH="./Dockerfile"
    export IMAGE_URI="${REGION}-docker.pkg.dev/${PROJECT_ID}/${REPOSITORY}/${IMAGE_NAME}:${IMAGE_TAG}"
    
    docker build -t "${IMAGE_URI}" -f "${DOCKERFILE_PATH}" .
    gcloud auth configure-docker "${REGION}-docker.pkg.dev" -q
    docker push "${IMAGE_URI}"
    

    Ganti REPOSITORY_NAME dengan nama repositori Artifact Registry Anda.

  3. Jalankan skrip Dockerfile:

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    FROM us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:jax0.7.2-rev1
    RUN apt-get update && apt-get install -y wget && rm -rf /var/lib/apt/lists/*
    
    RUN pip install --upgrade pip
    RUN pip install gemma grain equinox
    RUN pip install xprof
    
    WORKDIR /app
    
    # Copy your training script into the container
    COPY Gemma3LLMTrain.py .
    

    Dockerfile ini menginstal dependensi XProf.

Salin skrip penyesuaian ke dalam container

Di bagian ini, buat dan terapkan manifes Job Kubernetes yang mencakup pemasangan volume yang diperlukan untuk log XProf.

  1. Buka definisi Tugas training_singlehost.yaml:

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    apiVersion: batch/v1
    kind: Job
    metadata:
      name: jax-gemma3-train-singlehost
    spec:
      template:
        metadata:
          annotations:
            gke-gcsfuse/volumes: "true"
        spec:
          serviceAccountName: jaxserviceaccout
          containers:
          - name: training-container
            image: ${IMAGE_URI}
            imagePullPolicy: "Always"
            command: ["python", "Gemma3LLMTrain.py", "--maxlen", "256", "--batch_size", "64", "--datacount", "851200"]
            resources:
              limits:
                google.com/tpu: 1
            volumeMounts:
            - name: gcs-fuse-csi-ephemeral
              mountPath: /data
            - name: gcs-fuse-csi-ephemeral2
              mountPath: /xprof
          nodeSelector:
            cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
            cloud.google.com/gke-tpu-topology: 1x1
          restartPolicy: Never
          volumes:
          - name: gcs-fuse-csi-ephemeral
            csi:
              driver: gcsfuse.csi.storage.gke.io
              volumeAttributes:
                bucketName: ${GCS_BUCKET_NAME}
                mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
          - name: gcs-fuse-csi-ephemeral2
            csi:
              driver: gcsfuse.csi.storage.gke.io
              volumeAttributes:
                bucketName: ${XPROF_GCS_BUCKET_NAME}
                mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
      backoffLimit: 1
  2. Terapkan manifes:

    envsubst < training_singlehost.yaml | kubectl apply -f -
    

Beri akun layanan izin untuk menulis log XProf

  1. Untuk mengizinkan akun layanan menulis dan membaca, tambahkan peran "roles/storage.objectUser":

    export GSA_NAME="GSA_NAME" # Same as used in initial setup
    
    # Automatically get the current project ID
    export PROJECT_ID=$(gcloud config get-value project)
    
    # Cloud Storage Bucket details
    export XPROF_GCS_BUCKET_NAME="XPROF_GCS_BUCKET_NAME"
    
    # Derived Variables
    export GSA_EMAIL="${GSA_NAME}@${PROJECT_ID}.iam.gserviceaccount.com"
    
    gcloud storage buckets add-iam-policy-binding "gs://${XPROF_GCS_BUCKET_NAME}" \
        --member="serviceAccount:${GSA_EMAIL}" \
        --role="roles/storage.objectUser" \
        --project="${PROJECT_ID}"
    

    Ganti kode berikut:

    • GSA_NAME: nama Akun Layanan Google yang akan diberi peran.
    • XPROF_GCS_BUCKET_NAME: nama bucket tempat memberikan peran.
  2. Jalankan XProf di dalam Pod Anda:

    kubectl exec POD_NAME -c training-container -it -- bash # exec into the container
    xprof --port 9001 --logdir /xprof # start xprof
    

    Ganti POD_NAME dengan nama Pod Anda.

Mengakses dasbor XProf

  1. Siapkan penerusan port ke server XProf di Pod:

    kubectl port-forward POD_NAME 9001:9001
    
  2. Di kolom URL browser Anda, masukkan kode berikut:

    http://localhost:9001/
    

    XProf Trace Viewer akan terbuka.

  3. Di jendela TensorBoard, klik Rekam profil.

  4. Di kolom URL Layanan Profil atau nama TPU, masukkan localhost:9002.

  5. Untuk merekam lebih banyak detail, di Host Trace (TraceMe) Level, pilih verbose dan aktifkan logging trace Python.

  6. Untuk melihat dasbor, klik Ambil.

    TensorBoard mengambil profil dan memungkinkan Anda menganalisis performa skrip pelatihan. Grafik menunjukkan linimasa eksekusi untuk profil performa TPU dan CPU:

Contoh pelihat rekaman aktivitas XProf yang menampilkan grafik matriks performa

Untuk opsi pembuatan profil lainnya guna menganalisis performa workload pelatihan, lihat dokumentasi JAX tentang Pembuatan profil komputasi.

Penyesuaian di lingkungan produksi

Tutorial ini menunjukkan cara menguji pelatihan berbasis JAX dalam lingkungan terdistribusi. Untuk penyempurnaan LLM yang dioptimalkan dalam produksi, gunakan library Maxtext. Jika Anda tertarik dengan model difusi, gunakan penerapan Maxdiffusion.

Untuk workload pelatihan atau penyesuaian yang berjalan lama dalam produksi, siapkan checkpointing workload untuk meminimalkan hilangnya progres selama terjadi kegagalan. Untuk mempelajari lebih lanjut cara menyiapkan checkpoint multi-tingkat, lihat Melatih model machine learning berskala besar di GKE dengan Checkpointing Multi-Tingkat.

Pembersihan

Agar tidak perlu membayar biaya pada akun Google Cloud Anda untuk resource yang digunakan dalam tutorial ini, hapus project yang berisi resource tersebut, atau simpan project dan hapus setiap resource.

Menghapus resource satu per satu

Agar akun Google Cloud Anda tidak dikenai biaya untuk resource yang digunakan dalam tutorial ini, hapus project yang berisi resource tersebut, atau simpan project dan hapus setiap resource dengan menjalankan perintah berikut:

  1. Hapus resource yang Anda buat dalam tutorial ini:

    gcloud container clusters delete ${CLUSTER_NAME} --location=${REGION}
    gcloud storage rm --recursive gs://${GCS_BUCKET_NAME}
    gcloud artifacts docker images delete ${IMAGE_URI} --delete-tags
    
  2. Jika Anda tidak memerlukan data yang dihasilkan oleh XProf, hapus bucket Cloud Storage yang digunakan oleh XProf:

    gcloud storage rm --recursive gs://${XPROF_GCS_BUCKET_NAME}
    

Langkah berikutnya