GKE で JAX を使用して TPU で LLM をファインチューニングする

このチュートリアルでは、Google Kubernetes Engine(GKE)で Tensor Processing Unit(TPU)を使用して JAX で大規模言語モデル(LLM)をファインチューニングする方法について説明します。ファインチューニングでは、Gemma 3 などの基盤モデルを特定のドメインやタスクに適応させることができます。このプロセスでは、独自の特殊なデータセットでパラメータを更新することで、モデルの適合率と精度が向上します。

このガイドは、AI/ML ワークロードをファインチューニングする際に、マネージド Kubernetes での詳細な制御、カスタマイズ、スケーラビリティ、復元力、ポータビリティ、費用対効果が求められる場合の出発点として適しています。

背景

GKE で TPU を使用して Jax で LLM をファインチューニングすることで、マネージド Kubernetes のメリットをすべて活用した、本番環境対応の堅牢なファインチューニング ソリューションを構築できます。

Gemma

Gemma は、オープン ライセンスでリリースされて一般公開されている、軽量の生成 AI/ML マルチモーダル モデルのセットです。これらの AI モデルは、アプリケーション、ハードウェア、モバイル デバイス、ホスト型サービスで実行できます。Gemma 3 ではマルチモダリティが導入され、ビジョン言語入力とテキスト出力がサポートされています。最大 128,000 トークンのコンテキスト ウィンドウを処理でき、140 を超える言語に対応しています。また Gemma 3 では、構造化出力や関数呼び出しなど、数学、推論、チャット関連の機能が強化されています。

Gemma モデルはテキスト生成に使用できますが、特殊なタスク用にチューニングすることもできます。

詳細については、Gemma のドキュメントをご覧ください。

TPU

TPU は、TensorFlowPyTorchJAX などのフレームワークを使用して構築された ML モデルと AI モデルを高速化するために、Google が独自に開発した特定用途向け集積回路(ASIC)です。

GKE で TPU を使用する前に、次の学習プログラムを完了することをおすすめします。

  1. Cloud TPU システム アーキテクチャで、現在の TPU バージョンの可用性について学習する。
  2. GKE の TPU についてを確認する。

JAX

JAX は、TPU と GPU で使用するように設計された高パフォーマンスの ML フレームワークです。JAX は、ML モデルの構築とトレーニング用の API を提供します。

詳細については、JAX リポジトリをご覧ください。

目標

このチュートリアルでは、次の手順について説明します。

  1. モデルの特性に基づいて推奨される TPU トポロジを持つ GKE Autopilot または Standard クラスタを作成します。このチュートリアルでは、単一ホスト ノードプールでファインチューニングを行います。
  2. Cloud Storage バケットにデータを追加し、Cloud Storage FUSE を介してコンテナにマウントします。
  3. GKE に LLM ファインチューニング ジョブをデプロイします。
  4. ファインチューニング ジョブをモニタリングし、ログを表示します。

始める前に

  • Google Cloud アカウントにログインします。 Google Cloudを初めて使用する場合は、 アカウントを作成して、実際のシナリオでの Google プロダクトのパフォーマンスを評価してください。新規のお客様には、ワークロードの実行、テスト、デプロイができる無料クレジット $300 分を差し上げます。
  • In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Roles required to select or create a project

    • Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
    • Create a project: To create a project, you need the Project Creator role (roles/resourcemanager.projectCreator), which contains the resourcemanager.projects.create permission. Learn how to grant roles.

    Go to project selector

  • Verify that billing is enabled for your Google Cloud project.

  • Enable the required API.

    Roles required to enable APIs

    To enable APIs, you need the Service Usage Admin IAM role (roles/serviceusage.serviceUsageAdmin), which contains the serviceusage.services.enable permission. Learn how to grant roles.

    Enable the API

  • In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Roles required to select or create a project

    • Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
    • Create a project: To create a project, you need the Project Creator role (roles/resourcemanager.projectCreator), which contains the resourcemanager.projects.create permission. Learn how to grant roles.

    Go to project selector

  • Verify that billing is enabled for your Google Cloud project.

  • Enable the required API.

    Roles required to enable APIs

    To enable APIs, you need the Service Usage Admin IAM role (roles/serviceusage.serviceUsageAdmin), which contains the serviceusage.services.enable permission. Learn how to grant roles.

    Enable the API

  • プロジェクトに次のロール(複数の場合あり)が割り当てられていることを確認します。 roles/container.admin、roles/iam.serviceAccountAdmin、roles/storage.admin

    ロールを確認する

    1. Google Cloud コンソールで、[IAM] ページに移動します。

      IAM に移動
    2. プロジェクトを選択します。
    3. [プリンシパル] 列で、自分または自分が所属するグループの行をすべて確認します。所属するグループについては、管理者にお問い合わせください。

    4. 自分のメールアドレスを含む行の [ロール] 列で、ロールのリストに必要なロールが含まれているかどうか確認します。

    ロールを付与する

    1. Google Cloud コンソールで、[IAM] ページに移動します。

      IAM に移動
    2. プロジェクトを選択します。
    3. [ アクセスを許可] をクリックします。
    4. [新しいプリンシパル] フィールドに、ユーザー ID を入力します。 これは通常、Google アカウントのメールアドレスです。

    5. [ロールを選択] をクリックし、ロールを検索します。
    6. 追加のロールを付与するには、 [別のロールを追加] をクリックして各ロールを追加します。
    7. [保存] をクリックします。
  • 16 個の TPU Trillium(v6e)チップに十分な割り当てがあることを確認します。このチュートリアルでは、16 個のチップとオンデマンド インスタンスを必要とするノードプール構成を使用します。
  • Docker リポジトリがあることを確認します。ない場合は、Artifact Registry に標準リポジトリを作成します。

環境を準備する

このチュートリアルでは、Cloud Shell を使用して Google Cloudでホストされているリソースを管理します。Cloud Shell には、このチュートリアルに必要な kubectlGoogle Cloud CLI などのソフトウェアがプリインストールされています。

Cloud Shell を使用して環境を設定するには、次の操作を行います。

  1. Google Cloud コンソールで Cloud Shell セッションを起動し、Cloud Shell 有効化アイコンCloud Shell をアクティブにする)をクリックします。このアクションにより、 Google Cloud コンソールの下部ペインでセッションが起動します。

  2. デフォルトの環境変数を設定します。

    gcloud config set project PROJECT_ID
    gcloud config set billing/quota_project PROJECT_ID
    export PROJECT_ID=$(gcloud config get project)
    export CLUSTER_NAME=CLUSTER_NAME
    export REGION=CONTROL_PLANE_LOCATION
    export ZONE=ZONE
    export GCS_BUCKET_NAME=BUCKET_NAME
    

    次の値を置き換えます。

    • PROJECT_ID: 実際の Google Cloud プロジェクト ID
    • CLUSTER_NAME: GKE クラスタの名前。
    • CONTROL_PLANE_LOCATION: GKE クラスタと TPU ノードが配置されている Compute Engine リージョン。リージョンには、TPU Trillium(v6e)マシンタイプを使用できるゾーンが含まれている必要があります。
    • ZONE: 選択した CONTROL_PLANE_LOCATION リージョン内の、TPU Trillium(v6e)マシンタイプを使用できるゾーン。TPU Trillium(v6e)TPU を使用できるゾーンを一覧表示するには、次のコマンドを実行します。

        gcloud compute accelerator-types list --filter="name~ct6e" --format="value(zone)"
      
    • BUCKET_NAME: トレーニング データを含む Cloud Storage バケットの名前。

  3. サンプル リポジトリのクローンを作成します。

    git clone https://github.com/GoogleCloudPlatform/kubernetes-engine-samples.git
    cd kubernetes-engine-samples
    
  4. 作業ディレクトリに移動します。

    cd ai-ml/llm-training-jax-tpu-gemma3
    

Google Cloud リソースを作成して構成する

このセクションでは、 Google Cloud リソースを作成して構成します。

GKE クラスタを作成する

GKE Autopilot クラスタまたは GKE Standard クラスタの TPU で LLM をファインチューニングできます。フルマネージドの Kubernetes エクスペリエンスを実現するには、Autopilot クラスタを使用することをおすすめします。ワークロードに最適な GKE の運用モードを選択するには、GKE の運用モードを選択するをご覧ください。

Autopilot

Workload Identity Federation for GKE を使用し、Cloud Storage FUSE が有効になっている GKE Autopilot クラスタを作成します。

gcloud container clusters create-auto ${CLUSTER_NAME} \
    --location=${REGION}

クラスタの作成には数分かかることもあります。

標準

  1. Workload Identity Federation for GKE を使用し、Cloud Storage FUSE が有効になっているリージョン GKE Standard クラスタを作成します。

    gcloud container clusters create ${CLUSTER_NAME} \
        --enable-ip-alias \
        --addons GcsFuseCsiDriver \
        --machine-type=n2-standard-4 \
        --num-nodes=2 \
        --workload-pool=${PROJECT_ID}.svc.id.goog \
        --location=${REGION}
    

    クラスタの作成には数分かかることもあります。

  2. 単一ホストのノードプールを作成します。

    gcloud container node-pools create jax-tpu-nodepool \
        --cluster=${CLUSTER_NAME} \
        --machine-type=ct6e-standard-1t \
        --num-nodes=1 \
        --location=${REGION} \
        --node-locations=${ZONE} \
        --workload-metadata=GKE_METADATA
    

GKE は、1x1 トポロジと 1 つのノードを持つ TPU Trillium ノードプールを作成します。--workload-metadata=GKE_METADATA フラグにより、GKE メタデータ サーバーを使用するようにノードプールが構成されます。

JobSet をインストールする

  1. クラスタと通信を行うように kubectl を構成します。

    gcloud container clusters get-credentials ${CLUSTER_NAME} --location=${REGION}
    
  2. JobSet の最新リリース バージョンをインストールします。

    kubectl apply --server-side -f https://github.com/kubernetes-sigs/jobset/releases/download/JOBSET_VERSION/manifests.yaml
    

    JOBSET_VERSION は、JobSet の最新リリース バージョンに置き換えます。例: v0.11.0

  3. JobSet のインストールを確認します。

    kubectl get pods -n jobset-system
    

    出力は次のようになります。

    NAME                                         READY   STATUS    RESTARTS   AGE
    jobset-controller-manager-6c56668494-l4dhc   1/1     Running   0          4m45s
    

    JobSet がリソースを待機している場合は、ノードを追加する必要がある場合があります。

Cloud Storage FUSE を構成する

LLM をファインチューニングするには、トレーニング データを提供する必要があります。このチュートリアルでは、Hugging Face の TinyStories データセットを使用します。このデータセットには、GPT-3.5 と GPT-4 によって合成的に生成された短編小説が含まれています。これらの短編小説では、限られた語彙が使用されています。

このセクションでは、Cloud Storage FUSE が Cloud Storage バケットからデータを読み取るように構成する手順について説明します。

  1. データセットをダウンロードします

    wget https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories-train.txt?download=true -O TinyStories-train.txt
    
  2. データを新しい Cloud Storage バケットにアップロードします。

    gcloud storage buckets create gs://${GCS_BUCKET_NAME} \
        --location=${REGION} \
        --enable-hierarchical-namespace \
        --uniform-bucket-level-access
    gcloud storage cp TinyStories-train.txt gs://${GCS_BUCKET_NAME}
    
  3. ワークロードが Cloud Storage FUSE を介してデータを読み取れるようにするには、Kubernetes サービス アカウント(KSA)を作成し、必要な権限を追加します。permissionsetup.sh スクリプトを実行します。

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    #!/bin/bash
    
    # --- Configuration Variables ---
    # Kubernetes Service Account details
    export KSA_NAME="jaxserviceaccout"
    export NAMESPACE="default"
    
    # Google Cloud IAM Service Account details
    export GSA_NAME="<GSA_NAME>"
    # Automatically get the current project ID
    export PROJECT_ID=$(gcloud config get-value project)
    export  GSA_DESCRIPTION="GKE Service Account to read GCS bucket for ${KSA_NAME}"
    
    # GCS Bucket details
    export GCS_BUCKET_NAME="<GCS_BUCKET_NAME>" # <--- IMPORTANT: Update this to your bucket name
    
    # Derived Variables
    export GSA_EMAIL="${GSA_NAME}@${PROJECT_ID}.iam.gserviceaccount.com"
    export WI_MEMBER="serviceAccount:${PROJECT_ID}.svc.id.goog[${NAMESPACE}/${KSA_NAME}]"
    
    # --- Check if PROJECT_ID is set ---
    if [ -z "${PROJECT_ID}" ]; then
      echo "Error: PROJECT_ID is not set. Please set it using 'gcloud config set project YOUR_PROJECT_ID'"
      exit 1
    fi
    
    echo "--- Configuration ---"
    echo "KSA_NAME:      ${KSA_NAME}"
    echo "NAMESPACE:     ${NAMESPACE}"
    echo "GSA_NAME:      ${GSA_NAME}"
    echo "PROJECT_ID:    ${PROJECT_ID}"
    echo "GSA_EMAIL:     ${GSA_EMAIL}"
    echo "GCS_BUCKET_NAME:   ${GCS_BUCKET_NAME}"
    echo "WI_MEMBER:     ${WI_MEMBER}"
    echo "--------------------"
    read -p "Press enter to continue..."
    
    # --- Command Execution ---
    
    echo "[1/5] Creating Google Cloud IAM Service Account (GSA): ${GSA_NAME}"
    gcloud iam service-accounts create "${GSA_NAME}" \
        --project="${PROJECT_ID}" \
        --description="${GSA_DESCRIPTION}" \
        --display-name="${GSA_NAME}"
    
    echo "[2/5] Granting GSA '${GSA_EMAIL}' read access (roles/storage.objectViewer) to bucket 'gs://${GCS_BUCKET_NAME}'"
    gcloud storage buckets add-iam-policy-binding "gs://${GCS_BUCKET_NAME}" \
        --member="serviceAccount:${GSA_EMAIL}" \
        --role="roles/storage.objectViewer" \
        --project="${PROJECT_ID}"
    
    echo "[3/5] Creating Kubernetes Service Account (KSA): ${KSA_NAME} in namespace ${NAMESPACE}"
    kubectl create serviceaccount "${KSA_NAME}" --namespace "${NAMESPACE}"
    
    echo "[4/5] Allowing KSA to impersonate GSA (Workload Identity Binding): ${GSA_EMAIL}"
    gcloud iam service-accounts add-iam-policy-binding "${GSA_EMAIL}" \
        --role roles/iam.workloadIdentityUser \
        --member "${WI_MEMBER}" \
        --project="${PROJECT_ID}"
    
    echo "[5/5] Annotating KSA '${KSA_NAME}' to link with GSA '${GSA_EMAIL}'"
    kubectl annotate serviceaccount "${KSA_NAME}" \
        --namespace "${NAMESPACE}" \
        iam.gke.io/gcp-service-account="${GSA_EMAIL}"
    
    echo "--- Setup Complete ---"
    echo "Pods in namespace '${NAMESPACE}' using serviceAccount '${KSA_NAME}' can now authenticate as '${GSA_EMAIL}' and have read access to 'gs://${GCS_BUCKET_NAME}'."
    

    このスクリプトを実行すると、Google Cloud プロジェクトと GKE クラスタに次のリソースが構成されます。

    • プロジェクトに gcs-fuse-sa という名前の新しい IAM サービス アカウントが作成されます。
    • 作成された Google Cloud サービス アカウント(GSA)(gcs-fuse-sa)には、${GCS_BUCKET_NAME} で指定された Cloud Storage バケットに対する roles/storage.objectViewer ロールが付与されます。この権限により、GSA はバケットからオブジェクトを読み取ることができます。
    • GKE クラスタ内の default Namespace に、jaxserviceaccount という名前の新しい KSA が作成されます。
    • GSA の IAM ポリシーが更新され、KSA に roles/iam.workloadIdentityUser ロールが付与されます。この権限により、KSA は GSA を偽装できます。
    • KSA にアノテーションを付けて、GSA にリンクします。このアノテーションは、Workload Identity を使用して KSA がどの GSA を偽装するかを GKE に伝えます。

      jaxserviceaccount サービス アカウントを使用する GKE クラスタの default Namespace で実行されている Pod は、gcs-fuse-sa GSA として認証できるようになります。これらの Pod は gs://${GCS_BUCKET_NAME} バケットに保存されているオブジェクトに対する読み取りアクセス権を持ちます。これは、ファインチューニング ジョブが Cloud Storage FUSE を使用してデータセットにアクセスするために不可欠です。

ファインチューニング スクリプトを作成する

このセクションでは、Gemma 3 モデルでファインチューニング オペレーションを実行するトレーニング スクリプトについて説明します。このスクリプトでは Gemma3Tokenizer を使用します。

次の Gemma3LLMTrain.py ファインチューニング スクリプトを確認します。

# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import grain.python as pygrain
import jax
import jax.numpy as jnp
import optax
import pandas as pd
import time
import argparse

from dataclasses import dataclass
from functools import partial
from gemma import gm
from flax.training import train_state
from jax.sharding import Mesh, PartitionSpec, NamedSharding

jax.distributed.initialize()
print("Global device count:", jax.device_count())
print("jax version:", jax.__version__)

tokenizer = gm.text.Gemma3Tokenizer()
num_epochs = 1
learning_rate = 2e-5

@dataclass
class TextDataset:
    data: list
    maxlen: int
    tokenizer: gm.text.Gemma3Tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx: int):
        encoding = self.tokenizer.encode(self.data[idx])[:self.maxlen]  # Tokenize and truncate
        return encoding + [0] * (self.maxlen - len(encoding))  # Pad to maxlen

def load_and_preprocess_data(file_path, batch_size, maxlen, datacount, tokenizer):

    with open(file_path, 'r') as f:
      text = f.read()

    stories = text.split('<|endoftext|>')
    stories = [story for story in stories if story.strip()][:datacount]
    df = pd.DataFrame({'text': stories})
    data = df['text'].dropna().tolist()
    dataset = TextDataset(data, maxlen, tokenizer)

    sampler = pygrain.IndexSampler(
        len(dataset),
        shuffle=False,
        seed=42,
        shard_options=pygrain.NoSharding(),
        num_epochs=num_epochs,
    )

    dataloader = pygrain.DataLoader(
        data_source=dataset,
        sampler=sampler,
        operations=[pygrain.Batch(batch_size=batch_size, drop_remainder=True)],
    )

    return dataloader

def generate_text(model, params, tokenizer, prompt):
    sampler = gm.text.Sampler(
        model=model,
        params=params,
        tokenizer=tokenizer,
    )
    print("Generating response for: " + prompt)
    out = sampler.sample(prompt, max_new_tokens=32)
    print("Reponse: \n" + out + "\n")
    return out

prep_target_batch = jax.vmap(lambda tokens: jnp.concatenate((tokens[1:], jnp.array([0]))))

@partial(jax.jit, donate_argnums=(0,))
def train_step(state, batch):
    """Performs one supervised fine-tuning step."""

    def loss_fn(params):
        # Run the forward pass. The model returns logits.
        logits = state.apply_fn({'params': params}, batch[0]).logits

        # Calculate the cross-entropy loss.
        loss = optax.softmax_cross_entropy_with_integer_labels(
            logits=logits, labels=batch[1]
        ).mean()

        return loss

    # Compute gradients
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)

    # Update the model state
    state = state.apply_gradients(grads=grads)

    metrics = {'loss': loss}
    return state, metrics

def train_model(state, text_dl, num_epochs, sharding):
    batchCount = 0
    start_time = time.time()
    for epoch in range(num_epochs):
        start_time = time.time()
        for batch in text_dl:
            if len(batch) % len(jax.devices()) != 0:
              continue  # skip the remaining elements
            input_batch = jnp.array(jnp.array(batch).T)
            target_batch = prep_target_batch(input_batch)
            state, metrics = train_step(state, jax.device_put((input_batch, target_batch), sharding))

            if batchCount % 10 == 0:
                print(f"Loss after batch {batchCount}: {metrics['loss']}")
            batchCount += 1

    end_time = time.time()
    print(f"Completed training model. Total time for training {end_time - start_time} seconds \n")
    return state

def run_training(maxlen, batch_size, datacount):
    print(f"Batch size: {batch_size}, Max length: {maxlen}, Data count: {datacount}")
    #Load the training data
    tiny_stories_dl = load_and_preprocess_data('/data/TinyStories-train.txt', batch_size, maxlen, datacount, tokenizer)
    # Get the Gemma3 model
    model = gm.nn.Gemma3_270M()
    # Load the pretrained parameters
    params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_270M_PT)
    # Create an optimizer
    optimizer = optax.adamw(learning_rate=learning_rate)
    # Define sharding for data parallel training
    mesh = Mesh(jax.devices(), ('batch',))
    sharding = NamedSharding(mesh, PartitionSpec('batch', None))

    # Testing out current state of the model
    test_prompt = "Once upon a time, there was a girl named Amy."
    generate_text(model, params, tokenizer, test_prompt)

    state = train_state.TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=optimizer
    )

    # Perform post training
    print("Start training model")
    state = train_model(state, tiny_stories_dl, num_epochs, sharding)

    # Final text generation
    generate_text(model, state.params, tokenizer, test_prompt)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train Gemma model with custom parameters.')
    parser.add_argument('--maxlen', type=int, default=256, help='Maximum sequence length')
    parser.add_argument('--batch_size', type=int, default=128, help='Batch size')
    parser.add_argument('--datacount', type=int, default=296000, help='Number of data samples to use')
    args = parser.parse_args()

    run_training(maxlen=args.maxlen, batch_size=args.batch_size, datacount=args.datacount)

このスクリプトでは、次のようになります。

  • Gemma3Tokenizer は、モデルが処理できるトークンにテキストデータを変換します。
  • load_and_preprocess_data 関数は、ファイルからトレーニング データを読み取り、個々のストーリーに分割し、トークナイザーを使用してテキストをトークンのパディングされたシーケンスに変換します。
  • generate_text 関数は、モデル、そのパラメータ、プロンプトを受け取ってテキストを生成します。
  • train_step 関数は、フォワード パス、損失計算(クロスエントロピーを使用)、グラデーション計算、パラメータ更新を含むトレーニングの単一のイテレーションを定義します。
  • train_model 関数は、指定されたエポック数だけデータセットを反復処理し、バッチごとに train_step 関数を呼び出します。
  • run_training 関数は、データ読み込み、Gemma 3 モデル(Gemma3_270M)とオプティマイザーの初期化、事前トレーニング済みパラメータの読み込み、並列処理用のデータ シャーディングの設定、テスト生成の実行、トレーニング ループの実行、ファインチューニングの効果を示す最終的なテキスト生成の実行というプロセス全体を調整します。
  • このスクリプトは、argparse ライブラリを使用して、maxlenbatch_sizedatacount パラメータのコマンドライン引数を受け入れます。

ファインチューニング スクリプトを確認したので、GKE で実行できるようにコンテナ化します。

ファインチューニング スクリプトをコンテナ化する

GKE クラスタでファインチューニング スクリプトを実行する前に、コンテナ化する必要があります。このチュートリアルでは、ベースイメージとして JAX AI イメージを使用します。

  1. Gemma3LLMTrain.py ファイルと同じディレクトリにある Dockerfile を開きます。

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    FROM us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:jax0.7.2-rev1
    RUN apt-get update && apt-get install -y wget && rm -rf /var/lib/apt/lists/*
    
    RUN pip install --upgrade pip
    RUN pip install gemma grain
    
    WORKDIR /app
    
    # Copy your training script into the container
    COPY Gemma3LLMTrain.py .
    

    この Dockerfile は、必要な依存関係をインストールし、Gemma3LLMTrain.py ファイルをコンテナにコピーします。

  2. Docker イメージをビルドしてイメージ リポジトリに push します。

    export REPOSITORY=REPOSITORY_NAME
    export IMAGE_NAME="jax-gemma3-training"
    export IMAGE_TAG="latest"
    export DOCKERFILE_PATH="./Dockerfile"
    export IMAGE_URI="${REGION}-docker.pkg.dev/${PROJECT_ID}/${REPOSITORY}/${IMAGE_NAME}:${IMAGE_TAG}"
    
    docker build -t "${IMAGE_URI}" -f "${DOCKERFILE_PATH}" .
    gcloud auth configure-docker "${REGION}-docker.pkg.dev" -q
    docker push "${IMAGE_URI}"
    

    REPOSITORY_NAME は、Artifact Registry リポジトリの名前に置き換えます。

  3. サービス アカウントにロール バインディングを追加します。

    export PROJECT_NUMBER=$(gcloud projects describe $PROJECT_ID --format 'get(projectNumber)')
    gcloud artifacts repositories add-iam-policy-binding ${REPOSITORY} \
        --project=${PROJECT_ID} \
        --location=${REGION} \
        --member="serviceAccount:${PROJECT_NUMBER}-compute@developer.gserviceaccount.com" \
        --role="roles/artifactregistry.reader"
    

リポジトリにイメージが保存されたので、ファインチューニング ジョブを GKE クラスタにデプロイできます。

LLM ファインチューニング ジョブをデプロイする

このセクションでは、LLM ファインチューニング ジョブを GKE クラスタにデプロイする方法について説明します。

  1. training_singlehost.yaml マニフェストを開きます。

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    apiVersion: batch/v1
    kind: Job
    metadata:
      name: jax-gemma3-train-singlehost
    spec:
      template:
        metadata:
          annotations:
            gke-gcsfuse/volumes: "true"
        spec:
          serviceAccountName: jaxserviceaccout
          containers:
          - name: training-container
            image: ${IMAGE_URI}
            imagePullPolicy: "Always"
            command: ["python", "Gemma3LLMTrain.py", "--maxlen", "256", "--batch_size", "64", "--datacount", "355120"]
            resources:
              limits:
                google.com/tpu: 1
            volumeMounts:
            - name: gcs-fuse-csi-ephemeral
              mountPath: /data
          nodeSelector:
            cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
            cloud.google.com/gke-tpu-topology: 1x1
          restartPolicy: Never
          volumes:
          - name: gcs-fuse-csi-ephemeral
            csi:
              driver: gcsfuse.csi.storage.gke.io
              volumeAttributes:
                bucketName: ${GCS_BUCKET_NAME}
                mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
      backoffLimit: 1
  2. 次のようにマニフェストを適用します。

    envsubst < training_singlehost.yaml | kubectl apply -f -
    

GKE は、TPU Trillium(v6e)ノードで Pod を起動する Job を作成します。この Pod は、Python ファインチューニング スクリプトを実行します。このスクリプトは、Cloud Storage FUSE を使用して、/data パスにマウントされた指定の Cloud Storage バケットからファインチューニング データにアクセスします。その後、スクリプトは Gemma モデルをファインチューニングします。

トレーニング Job をモニタリングする

このセクションでは、ファインチューニング ジョブの進行状況とパフォーマンスをモニタリングします。

ファインチューニングの進行状況を確認する

  1. Pod を一覧表示します。

    # Find the Pods
    kubectl get pods
    
  2. ログ出力を追跡します。

    kubectl logs -f pods/POD_NAME
    

    POD_NAME は、Pod の名前に置き換えます。

    出力は次のようになります。

    Global device count: 1
    Batch size: 128, Max length: 256, Data count: 96000
    I1028 00:12:55.925999 1387 google_auth_provider.cc:181] Running on GCE, using service account ...
    Generating response for: Once upon a time, there was a girl named Amy.
    Response:
    Amy lived in a small house. The house was in a big field. Amy liked to play in the big field. She
    Start training model
    Loss after batch 0: 10.25
    Loss after batch 10: 4.3125
    .
    .
    .
    Loss after batch 740: 1.41406
    Completed training model. Total time for training 294.6791355609894 seconds
    Generating response for: Once upon a time, there was a girl named Amy.
    Response:
    She loved to play with her toys. One day, Amy's mom told her that she had to go to the store to
    
  3. 出力を分析します。

    • Global device count: 1 の線は、使用されている TPU コアを示します。
    • このファインチューニング実行の前に、モデルは事前トレーニング済みのチェックポイントから読み込むため、妥当なテキストを生成します。
    • ファインチューニング後に生成された出力は、短編小説の冒頭に似たものになり、モデルが新しいデータセットから学習していることがわかります。
    • データセット全体でファインチューニングを行うと、さらに洗練された出力が得られます。

指標をモニタリングする

TPU と CPU の指標を確認して、ファインチューニング ジョブのパフォーマンスを確認します。クラスタのオブザーバビリティ指標を表示するには、クラスタとワークロードのオブザーバビリティ指標を表示するの手順に沿って操作します。

代替のファインチューニング構成

このセクションでは、ファインチューニング ワークロードの代替構成について説明します。

モデルの選択

このチュートリアルでは、単一ホストの TPU Trillium(v6e)ノードプールに収まる小規模なモデルである Gemma3_270M モデルを使用しました。ファインチューニングに多くのメモリとコンピューティングを必要とする大規模なモデルの場合は、マルチホストまたはマルチスライス ノードプール構成を使用できます。

使用可能なモデルの一覧については、Gemma のドキュメントをご覧ください。

ノードプールの構成

このチュートリアルでは、単一ホストのノードプールを使用しました。必要に応じて、マルチホスト TPU スライス ノードプールまたはマルチスライス ノードプールを作成することもできます。

次のタブは、マルチホスト ノードプールとマルチスライス ノードプールを作成する方法を示しています。

マルチホスト

  1. Cloud Shell で、次のコマンドを実行します。

    gcloud container node-pools create jax-tpu-multihost1 \
        --cluster=${CLUSTER_NAME} \
        --machine-type=ct6e-standard-4t \
        --num-nodes=2 \
        --tpu-topology=2x4 \
        --location=${REGION} \
        --node-locations=${ZONE}
    

    GKE は、2x4 トポロジと 2 つのノードを持つ TPU Trillium ノードプールを作成します。

  2. training_multihost_jobset.yaml ジョブ定義を開きます。

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: jax-gemma3-train-multihost
    spec:
      replicatedJobs:
        - name: trainers
          replicas: 1
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 1
              template:
                metadata:
                  annotations:
                    gke-gcsfuse/volumes: "true"
                spec:
                  serviceAccountName: jaxserviceaccout
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                    cloud.google.com/gke-tpu-topology: 2x4
                    cloud.google.com/gke-nodepool: jax-tpu-multihost1
                  containers:
                  - name: training-container
                    image: ${IMAGE_URI} 
                    imagePullPolicy: "Always"
                    ports:
                      - containerPort: 8471
                    command: ["python", "Gemma3LLMTrain.py", "--maxlen", "256", "--batch_size", "64", "--datacount", "5120"]
                    resources:
                      limits:
                        google.com/tpu: 4
                    volumeMounts:
                    - name: gcs-fuse-csi-ephemeral
                      mountPath: /data
                  volumes:
                    - name: gcs-fuse-csi-ephemeral
                      csi:
                        driver: gcsfuse.csi.storage.gke.io
                        volumeAttributes:
                          bucketName: ${GCS_BUCKET_NAME}
                          mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
    
  3. ファインチューニング Job をデプロイします。

    envsubst < training_multihost_jobset.yaml | kubectl apply -f -
    

マルチスライス

  1. Cloud Shell で、次のコマンドを実行します。

    gcloud container node-pools create jax-tpu-multihost1 \
      --cluster=${CLUSTER_NAME} \
      --machine-type=ct6e-standard-4t \
      --num-nodes=2 \
      --tpu-topology=2x4 \
      --location=${REGION} \
      --node-locations=${ZONE}
    
    gcloud container node-pools create jax-tpu-multihost2 \
      --cluster=${CLUSTER_NAME} \
      --machine-type=ct6e-standard-4t \
      --num-nodes=2 \
      --tpu-topology=2x4 \
      --location=${REGION} \
      --node-locations=${ZONE}
    

    GKE は 2 つの TPU Trillium ノードプールを作成します。各ノードプールには、2x4 トポロジと 2 つのノードがあります。

  2. training_multislice_jobset.yaml ジョブ定義を開きます。

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: jax-gemma3-train-multislice
    spec:
      replicatedJobs:
        - name: trainers
          replicas: 2
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 1
              template:
                metadata:
                  annotations:
                    gke-gcsfuse/volumes: "true"
                spec:
                  serviceAccountName: jaxserviceaccout
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                    cloud.google.com/gke-tpu-topology: 2x4
                  containers:
                  - name: training-container
                    image: ${IMAGE_URI}
                    imagePullPolicy: "Always"
                    ports:
                      - containerPort: 8471
                    command: ["python", "Gemma3LLMTrain.py", "--maxlen", "256", "--batch_size", "64", "--datacount", "5120"]
                    resources:
                      limits:
                        google.com/tpu: 4
                    volumeMounts:
                    - name: gcs-fuse-csi-ephemeral
                      mountPath: /data
                  volumes:
                    - name: gcs-fuse-csi-ephemeral
                      csi:
                        driver: gcsfuse.csi.storage.gke.io
                        volumeAttributes:
                          bucketName: ${GCS_BUCKET_NAME}
                          mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:100"
    
  3. ファインチューニング Job をデプロイします。

    envsubst < training_multislice_jobset.yaml | kubectl apply -f -
    

パフォーマンスの分析と最適化

ML ファインチューニングのパフォーマンスを分析して最適化するには、XProf を使用します。XProf は、JAX、TensorFlow、PyTorch/XLA で構築された ML ワークロードをプロファイリングして検査する一連のツールです。実行トレース、メモリ使用量、その他のデータを表示することで、XProf を使用してモデルとトレーニング設定を微調整し、効率を高めてトレーニングを高速化できます。

XProf を使用してファインチューニング ワークロードのパフォーマンスを分析するには、このセクションで次の手順を行います。

  • xprof パッケージをインストールします。XProf サーバーを起動するようにトレーニング スクリプトを変更します。
  • XProf ログのボリューム マウントを含めるように Kubernetes Job マニフェストを変更します。
  • Cloud Storage バケットに XProf ログを書き込む権限をサービス アカウントに付与します。
  • Pod 内で XProf を実行し、XProf ダッシュボードにアクセスするようにポート転送を設定します。

XProf パッケージをインストールする

  1. XProf サンプルが含まれているディレクトリに移動します。

      cd ai-ml/llm-training-jax-tpu-gemma3/xprof-enabled
    
  2. Docker イメージをビルドしてイメージ リポジトリに push します。

    export REPOSITORY=REPOSITORY_NAME
    export IMAGE_NAME="jax-gemma3-training-xp"
    export IMAGE_TAG="latest"
    export DOCKERFILE_PATH="./Dockerfile"
    export IMAGE_URI="${REGION}-docker.pkg.dev/${PROJECT_ID}/${REPOSITORY}/${IMAGE_NAME}:${IMAGE_TAG}"
    
    docker build -t "${IMAGE_URI}" -f "${DOCKERFILE_PATH}" .
    gcloud auth configure-docker "${REGION}-docker.pkg.dev" -q
    docker push "${IMAGE_URI}"
    

    REPOSITORY_NAME は、Artifact Registry リポジトリの名前に置き換えます。

  3. Dockerfile スクリプトを実行します。

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    FROM us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:jax0.7.2-rev1
    RUN apt-get update && apt-get install -y wget && rm -rf /var/lib/apt/lists/*
    
    RUN pip install --upgrade pip
    RUN pip install gemma grain equinox
    RUN pip install xprof
    
    WORKDIR /app
    
    # Copy your training script into the container
    COPY Gemma3LLMTrain.py .
    

    この Dockerfile は XProf の依存関係をインストールします。

ファインチューニング スクリプトをコンテナにコピーする

このセクションでは、XProf ログに必要なボリューム マウントを含む Kubernetes Job マニフェストを作成して適用します。

  1. training_singlehost.yaml ジョブ定義を開きます。

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    apiVersion: batch/v1
    kind: Job
    metadata:
      name: jax-gemma3-train-singlehost
    spec:
      template:
        metadata:
          annotations:
            gke-gcsfuse/volumes: "true"
        spec:
          serviceAccountName: jaxserviceaccout
          containers:
          - name: training-container
            image: ${IMAGE_URI}
            imagePullPolicy: "Always"
            command: ["python", "Gemma3LLMTrain.py", "--maxlen", "256", "--batch_size", "64", "--datacount", "851200"]
            resources:
              limits:
                google.com/tpu: 1
            volumeMounts:
            - name: gcs-fuse-csi-ephemeral
              mountPath: /data
            - name: gcs-fuse-csi-ephemeral2
              mountPath: /xprof
          nodeSelector:
            cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
            cloud.google.com/gke-tpu-topology: 1x1
          restartPolicy: Never
          volumes:
          - name: gcs-fuse-csi-ephemeral
            csi:
              driver: gcsfuse.csi.storage.gke.io
              volumeAttributes:
                bucketName: ${GCS_BUCKET_NAME}
                mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
          - name: gcs-fuse-csi-ephemeral2
            csi:
              driver: gcsfuse.csi.storage.gke.io
              volumeAttributes:
                bucketName: ${XPROF_GCS_BUCKET_NAME}
                mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
      backoffLimit: 1
  2. 次のようにマニフェストを適用します。

    envsubst < training_singlehost.yaml | kubectl apply -f -
    

サービス アカウントに XProf ログを書き込む権限を付与する

  1. サービス アカウントで書き込みと読み取りを有効にするには、"roles/storage.objectUser" ロールを追加します。

    export GSA_NAME="GSA_NAME" # Same as used in initial setup
    
    # Automatically get the current project ID
    export PROJECT_ID=$(gcloud config get-value project)
    
    # Cloud Storage Bucket details
    export XPROF_GCS_BUCKET_NAME="XPROF_GCS_BUCKET_NAME"
    
    # Derived Variables
    export GSA_EMAIL="${GSA_NAME}@${PROJECT_ID}.iam.gserviceaccount.com"
    
    gcloud storage buckets add-iam-policy-binding "gs://${XPROF_GCS_BUCKET_NAME}" \
        --member="serviceAccount:${GSA_EMAIL}" \
        --role="roles/storage.objectUser" \
        --project="${PROJECT_ID}"
    

    次のように置き換えます。

    • GSA_NAME: ロールを付与する Google サービス アカウントの名前。
    • XPROF_GCS_BUCKET_NAME: ロールを付与するバケットの名前。
  2. Pod 内で XProf を実行します。

    kubectl exec POD_NAME -c training-container -it -- bash # exec into the container
    xprof --port 9001 --logdir /xprof # start xprof
    

    POD_NAME は、Pod の名前に置き換えます。

XProf ダッシュボードにアクセスする

  1. Pod の XProf サーバーへのポート転送を設定します。

    kubectl port-forward POD_NAME 9001:9001
    
  2. ブラウザのアドレスバーに次のように入力します。

    http://localhost:9001/
    

    XProf トレースビューアが開きます。

  3. TensorBoard ウィンドウで [Capture profile] をクリックします。

  4. [Profile Service URL(s) or TPU name] フィールドに「localhost:9002」と入力します。

  5. 詳細をキャプチャするには、[ホスト トレース(TraceMe)レベル] で [詳細] を選択し、Python トレース ロギングを有効にします。

  6. ダッシュボードを表示するには、[キャプチャ] をクリックします。

    TensorBoard はプロファイルをキャプチャし、トレーニング スクリプトのパフォーマンスを分析できます。グラフには、TPU と CPU の両方のパフォーマンス プロファイルの実行タイムラインが表示されます。

パフォーマンス マトリックス グラフを示す XProf トレース ビューアの例

トレーニング ワークロードのパフォーマンスを分析するためのその他のプロファイリング オプションについては、計算のプロファイリングに関する JAX ドキュメントをご覧ください。

本番環境でのファインチューニング

このチュートリアルでは、分散環境で JAX ベースのトレーニングをテストする方法について説明しました。本番環境で最適化された LLM ファインチューニングを行うには、Maxtext ライブラリを使用します。拡散モデルに関心がある場合は、Maxdiffusion 実装を使用します。

本番環境で長時間実行されるトレーニング ワークロードまたはファインチューニング ワークロードの場合は、障害発生時の進行状況の損失を最小限に抑えるようにワークロードのチェックポイント処理を設定します。多層チェックポイント処理の設定の詳細については、多層チェックポイント処理を使用して GKE で大規模な ML モデルをトレーニングするをご覧ください。

クリーンアップ

このチュートリアルで使用したリソースについて、Google Cloud アカウントに課金されないようにするには、リソースを含むプロジェクトを削除するか、プロジェクトを維持して個々のリソースを削除します。

個々のリソースの削除

このチュートリアルで使用したリソースについて、 Google Cloud アカウントに課金されないようにするには、リソースを含むプロジェクトを削除するか、プロジェクトを維持して次のコマンドを実行して個々のリソースを削除します。

  1. このチュートリアルで作成したリソースを削除します。

    gcloud container clusters delete ${CLUSTER_NAME} --location=${REGION}
    gcloud storage rm --recursive gs://${GCS_BUCKET_NAME}
    gcloud artifacts docker images delete ${IMAGE_URI} --delete-tags
    
  2. XProf によって生成されたデータが不要な場合は、XProf で使用されている Cloud Storage バケットを削除します。

    gcloud storage rm --recursive gs://${XPROF_GCS_BUCKET_NAME}
    

次のステップ