Affiner un LLM à l'aide de TPU sur GKE avec JAX

Ce tutoriel explique comment affiner un grand modèle de langage (LLM) à l'aide de Tensor Processing Units (TPU) sur Google Kubernetes Engine (GKE) avec JAX. L'affinage vous permet d'adapter un modèle de fondation tel que Gemma 3 à un domaine ou une tâche spécifique. Ce processus améliore la précision et la justesse du modèle en mettant à jour ses paramètres avec votre propre ensemble de données spécialisé.

Ce guide est un bon point de départ si vous avez besoin du contrôle précis, de la personnalisation, de l'évolutivité, de la résilience, de la portabilité et de la rentabilité des services Kubernetes gérés lors de l'affinage de vos charges de travail d'IA/de ML.

Arrière-plan

En utilisant des TPU sur GKE avec Jax pour affiner un LLM, vous pouvez créer une solution d'affinage robuste et prête pour la production avec tous les avantages de Kubernetes géré.

Gemma

Gemma est un ensemble de modèles multimodaux d'IA/ML générative, légers et disponibles publiquement, publiés sous licence ouverte. Ces modèles d'IA sont disponibles pour s'exécuter dans vos applications, votre matériel, vos appareils mobiles ou vos services hébergés. Gemma 3 introduit la multimodalité et accepte les entrées de langage visuel et les sorties de texte. Il gère les fenêtres de contexte jusqu'à 128 000 jetons et est compatible avec plus de 140 langues. Gemma 3 offre également des capacités améliorées en mathématiques, en raisonnement et en discussion, y compris des sorties structurées et l'appel de fonctions.

Vous pouvez utiliser les modèles Gemma pour la génération de texte, mais vous pouvez également les ajuster pour des tâches spécialisées.

Pour en savoir plus, consultez la documentation Gemma.

TPU

Les TPU sont des circuits intégrés propres aux applications (ASIC) développés spécifiquement par Google pour accélérer le machine learning et les modèles d'IA créés à l'aide de frameworks tels que TensorFlow, PyTorch et JAX.

Avant d'utiliser des TPU dans GKE, nous vous recommandons de suivre le parcours de formation suivant :

  1. Découvrez la disponibilité actuelle des versions de TPU avec l'architecture système de Cloud TPU.
  2. En savoir plus sur les TPU dans GKE

JAX

JAX est un framework de machine learning hautes performances conçu pour être utilisé avec des TPU et des GPU. JAX fournit une API pour créer et entraîner des modèles de machine learning.

Pour en savoir plus, consultez le dépôt JAX.

Objectifs

Ce tutoriel couvre les étapes suivantes :

  1. Créer un cluster GKE Autopilot ou standard avec la topologie TPU recommandée en fonction des caractéristiques du modèle. Dans ce tutoriel, vous allez affiner les pools de nœuds à hôte unique.
  2. Ajoutez des données à un bucket Cloud Storage et installez-les dans le conteneur via Cloud Storage FUSE.
  3. Déployez le job d'affinage du LLM sur GKE.
  4. Surveillez le job d'affinage et consultez les journaux.

Avant de commencer

  • Connectez-vous à votre compte Google Cloud . Si vous débutez sur Google Cloud, créez un compte pour évaluer les performances de nos produits en conditions réelles. Les nouveaux clients bénéficient également de 300 $ de crédits sans frais pour exécuter, tester et déployer des charges de travail.
  • In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Roles required to select or create a project

    • Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
    • Create a project: To create a project, you need the Project Creator role (roles/resourcemanager.projectCreator), which contains the resourcemanager.projects.create permission. Learn how to grant roles.

    Go to project selector

  • Verify that billing is enabled for your Google Cloud project.

  • Enable the required API.

    Roles required to enable APIs

    To enable APIs, you need the Service Usage Admin IAM role (roles/serviceusage.serviceUsageAdmin), which contains the serviceusage.services.enable permission. Learn how to grant roles.

    Enable the API

  • In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Roles required to select or create a project

    • Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
    • Create a project: To create a project, you need the Project Creator role (roles/resourcemanager.projectCreator), which contains the resourcemanager.projects.create permission. Learn how to grant roles.

    Go to project selector

  • Verify that billing is enabled for your Google Cloud project.

  • Enable the required API.

    Roles required to enable APIs

    To enable APIs, you need the Service Usage Admin IAM role (roles/serviceusage.serviceUsageAdmin), which contains the serviceusage.services.enable permission. Learn how to grant roles.

    Enable the API

  • Assurez-vous de disposer des rôles suivants sur le projet : roles/container.admin,roles/iam.serviceAccountAdmin,roles/storage.admin

    Vérifier les rôles

    1. Dans la console Google Cloud , accédez à la page IAM.

      Accéder à IAM
    2. Sélectionnez le projet.
    3. Dans la colonne Compte principal, recherchez toutes les lignes qui vous identifient ou identifient un groupe dont vous faites partie. Pour savoir à quels groupes vous appartenez, contactez votre administrateur.

    4. Pour toutes les lignes qui vous spécifient ou vous incluent, consultez la colonne Rôle pour vous assurer que la liste inclut les rôles requis.

    Attribuer les rôles

    1. Dans la console Google Cloud , accédez à la page IAM.

      Accéder à IAM
    2. Sélectionnez le projet.
    3. Cliquez sur  Accorder l'accès.
    4. Dans le champ Nouveaux comptes principaux, saisissez votre identifiant utilisateur. Il s'agit généralement de l'adresse e-mail d'un compte Google.

    5. Cliquez sur Sélectionner un rôle, puis recherchez le rôle.
    6. Pour attribuer des rôles supplémentaires, cliquez sur  Ajouter un autre rôle et ajoutez tous les rôles supplémentaires.
    7. Cliquez sur Enregistrer.
  • Assurez-vous de disposer d'un quota suffisant pour 16 puces TPU Trillium (v6e). Dans ce tutoriel, vous utilisez une configuration de pool de nœuds qui nécessite 16 puces et des instances à la demande.
  • Assurez-vous de disposer d'un dépôt Docker. Si vous n'en avez pas, créez un dépôt standard dans Artifact Registry.

Préparer l'environnement

Dans ce tutoriel, vous utilisez Cloud Shell pour gérer les ressources hébergées sur Google Cloud. Cloud Shell est préinstallé avec les logiciels dont vous avez besoin pour ce tutoriel, y compris kubectl et Google Cloud CLI.

Pour configurer votre environnement avec Cloud Shell, procédez comme suit :

  1. Dans la console Google Cloud , lancez une session Cloud Shell et cliquez sur Icône d'activation Cloud Shell Activer Cloud Shell. Une session s'ouvre dans le volet inférieur de la console Google Cloud .

  2. Définissez les variables d'environnement par défaut :

    gcloud config set project PROJECT_ID
    gcloud config set billing/quota_project PROJECT_ID
    export PROJECT_ID=$(gcloud config get project)
    export CLUSTER_NAME=CLUSTER_NAME
    export REGION=CONTROL_PLANE_LOCATION
    export ZONE=ZONE
    export GCS_BUCKET_NAME=BUCKET_NAME
    

    Remplacez les valeurs suivantes :

    • PROJECT_ID : ID de votre projet Google Cloud .
    • CLUSTER_NAME : nom de votre cluster GKE.
    • CONTROL_PLANE_LOCATION : région Compute Engine où se trouvent votre cluster GKE et vos nœuds TPU. La région doit contenir des zones dans lesquelles les types de machines TPU Trillium (v6e) sont disponibles.
    • ZONE : zone de la région CONTROL_PLANE_LOCATION que vous avez sélectionnée, où les types de machines TPU Trillium (v6e) sont disponibles. Pour lister les zones dans lesquelles les TPU Trillium (v6e) sont disponibles, exécutez la commande suivante :

        gcloud compute accelerator-types list --filter="name~ct6e" --format="value(zone)"
      
    • BUCKET_NAME : nom du bucket Cloud Storage contenant vos données d'entraînement.

  3. Clonez l'exemple de dépôt :

    git clone https://github.com/GoogleCloudPlatform/kubernetes-engine-samples.git
    cd kubernetes-engine-samples
    
  4. Accédez au répertoire de travail :

    cd ai-ml/llm-training-jax-tpu-gemma3
    

Créer et configurer des ressources Google Cloud

Dans cette section, vous allez créer et configurer des ressources Google Cloud .

Créer un cluster GKE

Vous pouvez affiner un LLM sur des TPU dans un cluster GKE Autopilot ou GKE Standard. Nous vous recommandons d'utiliser un cluster GKE Autopilot pour une expérience Kubernetes entièrement gérée. Pour choisir le mode de fonctionnement GKE le mieux adapté à vos charges de travail, consultez Choisir un mode de fonctionnement GKE.

Autopilot

Créez un cluster GKE Autopilot qui utilise Workload Identity Federation for GKE et dans lequel Cloud Storage FUSE est activé.

gcloud container clusters create-auto ${CLUSTER_NAME} \
    --location=${REGION}

La création du cluster peut prendre plusieurs minutes.

Standard

  1. Créez un cluster GKE Standard régional qui utilise la fédération d'identité de charge de travail pour GKE et sur lequel Cloud Storage FUSE est activé.

    gcloud container clusters create ${CLUSTER_NAME} \
        --enable-ip-alias \
        --addons GcsFuseCsiDriver \
        --machine-type=n2-standard-4 \
        --num-nodes=2 \
        --workload-pool=${PROJECT_ID}.svc.id.goog \
        --location=${REGION}
    

    La création du cluster peut prendre plusieurs minutes.

  2. Créez un pool de nœuds à hôte unique :

    gcloud container node-pools create jax-tpu-nodepool \
        --cluster=${CLUSTER_NAME} \
        --machine-type=ct6e-standard-1t \
        --num-nodes=1 \
        --location=${REGION} \
        --node-locations=${ZONE} \
        --workload-metadata=GKE_METADATA
    

GKE crée un pool de nœuds TPU Trillium avec une topologie 1x1 et un nœud. L'option --workload-metadata=GKE_METADATA configure le pool de nœuds de sorte qu'il utilise le serveur de métadonnées GKE.

Installer JobSet

  1. Configurez kubectl de manière à communiquer avec votre cluster :

    gcloud container clusters get-credentials ${CLUSTER_NAME} --location=${REGION}
    
  2. Installez la dernière version de JobSet :

    kubectl apply --server-side -f https://github.com/kubernetes-sigs/jobset/releases/download/JOBSET_VERSION/manifests.yaml
    

    Remplacez JOBSET_VERSION par la dernière version publiée de JobSet. Par exemple, v0.11.0.

  3. Vérifiez l'installation de JobSet :

    kubectl get pods -n jobset-system
    

    Le résultat ressemble à ce qui suit :

    NAME                                         READY   STATUS    RESTARTS   AGE
    jobset-controller-manager-6c56668494-l4dhc   1/1     Running   0          4m45s
    

    Vous devrez peut-être ajouter des nœuds si JobSet attend des ressources.

Configurer Cloud Storage FUSE

Pour affiner le LLM, vous devez fournir des données d'entraînement. Dans ce tutoriel, vous allez utiliser l'ensemble de données TinyStories de Hugging Face. Cet ensemble de données contient des nouvelles générées de manière synthétique par GPT-3.5 et GPT-4, qui utilisent un vocabulaire limité.

Cette section explique comment configurer Cloud Storage FUSE pour lire les données d'un bucket Cloud Storage.

  1. Téléchargez l'ensemble de données :

    wget https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories-train.txt?download=true -O TinyStories-train.txt
    
  2. Importez les données dans un nouveau bucket Cloud Storage :

    gcloud storage buckets create gs://${GCS_BUCKET_NAME} \
        --location=${REGION} \
        --enable-hierarchical-namespace \
        --uniform-bucket-level-access
    gcloud storage cp TinyStories-train.txt gs://${GCS_BUCKET_NAME}
    
  3. Pour permettre à votre charge de travail de lire des données via Cloud Storage FUSE, créez un compte de service Kubernetes (KSA) et ajoutez les autorisations requises. Exécutez le script permissionsetup.sh :

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    #!/bin/bash
    
    # --- Configuration Variables ---
    # Kubernetes Service Account details
    export KSA_NAME="jaxserviceaccout"
    export NAMESPACE="default"
    
    # Google Cloud IAM Service Account details
    export GSA_NAME="<GSA_NAME>"
    # Automatically get the current project ID
    export PROJECT_ID=$(gcloud config get-value project)
    export  GSA_DESCRIPTION="GKE Service Account to read GCS bucket for ${KSA_NAME}"
    
    # GCS Bucket details
    export GCS_BUCKET_NAME="<GCS_BUCKET_NAME>" # <--- IMPORTANT: Update this to your bucket name
    
    # Derived Variables
    export GSA_EMAIL="${GSA_NAME}@${PROJECT_ID}.iam.gserviceaccount.com"
    export WI_MEMBER="serviceAccount:${PROJECT_ID}.svc.id.goog[${NAMESPACE}/${KSA_NAME}]"
    
    # --- Check if PROJECT_ID is set ---
    if [ -z "${PROJECT_ID}" ]; then
      echo "Error: PROJECT_ID is not set. Please set it using 'gcloud config set project YOUR_PROJECT_ID'"
      exit 1
    fi
    
    echo "--- Configuration ---"
    echo "KSA_NAME:      ${KSA_NAME}"
    echo "NAMESPACE:     ${NAMESPACE}"
    echo "GSA_NAME:      ${GSA_NAME}"
    echo "PROJECT_ID:    ${PROJECT_ID}"
    echo "GSA_EMAIL:     ${GSA_EMAIL}"
    echo "GCS_BUCKET_NAME:   ${GCS_BUCKET_NAME}"
    echo "WI_MEMBER:     ${WI_MEMBER}"
    echo "--------------------"
    read -p "Press enter to continue..."
    
    # --- Command Execution ---
    
    echo "[1/5] Creating Google Cloud IAM Service Account (GSA): ${GSA_NAME}"
    gcloud iam service-accounts create "${GSA_NAME}" \
        --project="${PROJECT_ID}" \
        --description="${GSA_DESCRIPTION}" \
        --display-name="${GSA_NAME}"
    
    echo "[2/5] Granting GSA '${GSA_EMAIL}' read access (roles/storage.objectViewer) to bucket 'gs://${GCS_BUCKET_NAME}'"
    gcloud storage buckets add-iam-policy-binding "gs://${GCS_BUCKET_NAME}" \
        --member="serviceAccount:${GSA_EMAIL}" \
        --role="roles/storage.objectViewer" \
        --project="${PROJECT_ID}"
    
    echo "[3/5] Creating Kubernetes Service Account (KSA): ${KSA_NAME} in namespace ${NAMESPACE}"
    kubectl create serviceaccount "${KSA_NAME}" --namespace "${NAMESPACE}"
    
    echo "[4/5] Allowing KSA to impersonate GSA (Workload Identity Binding): ${GSA_EMAIL}"
    gcloud iam service-accounts add-iam-policy-binding "${GSA_EMAIL}" \
        --role roles/iam.workloadIdentityUser \
        --member "${WI_MEMBER}" \
        --project="${PROJECT_ID}"
    
    echo "[5/5] Annotating KSA '${KSA_NAME}' to link with GSA '${GSA_EMAIL}'"
    kubectl annotate serviceaccount "${KSA_NAME}" \
        --namespace "${NAMESPACE}" \
        iam.gke.io/gcp-service-account="${GSA_EMAIL}"
    
    echo "--- Setup Complete ---"
    echo "Pods in namespace '${NAMESPACE}' using serviceAccount '${KSA_NAME}' can now authenticate as '${GSA_EMAIL}' and have read access to 'gs://${GCS_BUCKET_NAME}'."
    

    Une fois ce script exécuté, les ressources suivantes sont configurées dans votre projetGoogle Cloud et votre cluster GKE :

    • Un compte de service IAM nommé gcs-fuse-sa est créé dans votre projet.
    • Le compte de service Google Cloud (GSA) créé (gcs-fuse-sa) reçoit le rôle roles/storage.objectViewer sur le bucket Cloud Storage spécifié par ${GCS_BUCKET_NAME}. Google Cloud Cette autorisation permet à l'AGS de lire les objets du bucket.
    • Un nouveau compte de service Kubernetes nommé jaxserviceaccount est créé dans l'espace de noms default de votre cluster GKE.
    • La stratégie IAM du compte de service Google est mise à jour pour accorder le rôle roles/iam.workloadIdentityUser au compte de service Kubernetes. Cette autorisation permet au KSA d'emprunter l'identité du GSA.
    • Le KSA est annoté pour être associé au GSA. Cette annotation indique à GKE quel GSA le KSA doit emprunter en utilisant Workload Identity.

      Tout pod s'exécutant dans l'espace de noms default de votre cluster GKE et utilisant le compte de service jaxserviceaccount pourra désormais s'authentifier en tant que compte de service Google gcs-fuse-sa. Ces pods auront un accès en lecture aux objets stockés dans le bucket gs://${GCS_BUCKET_NAME}, ce qui est essentiel pour que le job d'affinage puisse accéder au jeu de données à l'aide de Cloud Storage FUSE.

Créer le script d'affinage

Dans cette section, vous allez explorer le script d'entraînement qui effectue une opération de réglage fin sur un modèle Gemma 3. Ce script utilise Gemma3Tokenizer.

Examinez le script de réglage fin Gemma3LLMTrain.py suivant :

# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import grain.python as pygrain
import jax
import jax.numpy as jnp
import optax
import pandas as pd
import time
import argparse

from dataclasses import dataclass
from functools import partial
from gemma import gm
from flax.training import train_state
from jax.sharding import Mesh, PartitionSpec, NamedSharding

jax.distributed.initialize()
print("Global device count:", jax.device_count())
print("jax version:", jax.__version__)

tokenizer = gm.text.Gemma3Tokenizer()
num_epochs = 1
learning_rate = 2e-5

@dataclass
class TextDataset:
    data: list
    maxlen: int
    tokenizer: gm.text.Gemma3Tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx: int):
        encoding = self.tokenizer.encode(self.data[idx])[:self.maxlen]  # Tokenize and truncate
        return encoding + [0] * (self.maxlen - len(encoding))  # Pad to maxlen

def load_and_preprocess_data(file_path, batch_size, maxlen, datacount, tokenizer):

    with open(file_path, 'r') as f:
      text = f.read()

    stories = text.split('<|endoftext|>')
    stories = [story for story in stories if story.strip()][:datacount]
    df = pd.DataFrame({'text': stories})
    data = df['text'].dropna().tolist()
    dataset = TextDataset(data, maxlen, tokenizer)

    sampler = pygrain.IndexSampler(
        len(dataset),
        shuffle=False,
        seed=42,
        shard_options=pygrain.NoSharding(),
        num_epochs=num_epochs,
    )

    dataloader = pygrain.DataLoader(
        data_source=dataset,
        sampler=sampler,
        operations=[pygrain.Batch(batch_size=batch_size, drop_remainder=True)],
    )

    return dataloader

def generate_text(model, params, tokenizer, prompt):
    sampler = gm.text.Sampler(
        model=model,
        params=params,
        tokenizer=tokenizer,
    )
    print("Generating response for: " + prompt)
    out = sampler.sample(prompt, max_new_tokens=32)
    print("Reponse: \n" + out + "\n")
    return out

prep_target_batch = jax.vmap(lambda tokens: jnp.concatenate((tokens[1:], jnp.array([0]))))

@partial(jax.jit, donate_argnums=(0,))
def train_step(state, batch):
    """Performs one supervised fine-tuning step."""

    def loss_fn(params):
        # Run the forward pass. The model returns logits.
        logits = state.apply_fn({'params': params}, batch[0]).logits

        # Calculate the cross-entropy loss.
        loss = optax.softmax_cross_entropy_with_integer_labels(
            logits=logits, labels=batch[1]
        ).mean()

        return loss

    # Compute gradients
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)

    # Update the model state
    state = state.apply_gradients(grads=grads)

    metrics = {'loss': loss}
    return state, metrics

def train_model(state, text_dl, num_epochs, sharding):
    batchCount = 0
    start_time = time.time()
    for epoch in range(num_epochs):
        start_time = time.time()
        for batch in text_dl:
            if len(batch) % len(jax.devices()) != 0:
              continue  # skip the remaining elements
            input_batch = jnp.array(jnp.array(batch).T)
            target_batch = prep_target_batch(input_batch)
            state, metrics = train_step(state, jax.device_put((input_batch, target_batch), sharding))

            if batchCount % 10 == 0:
                print(f"Loss after batch {batchCount}: {metrics['loss']}")
            batchCount += 1

    end_time = time.time()
    print(f"Completed training model. Total time for training {end_time - start_time} seconds \n")
    return state

def run_training(maxlen, batch_size, datacount):
    print(f"Batch size: {batch_size}, Max length: {maxlen}, Data count: {datacount}")
    #Load the training data
    tiny_stories_dl = load_and_preprocess_data('/data/TinyStories-train.txt', batch_size, maxlen, datacount, tokenizer)
    # Get the Gemma3 model
    model = gm.nn.Gemma3_270M()
    # Load the pretrained parameters
    params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_270M_PT)
    # Create an optimizer
    optimizer = optax.adamw(learning_rate=learning_rate)
    # Define sharding for data parallel training
    mesh = Mesh(jax.devices(), ('batch',))
    sharding = NamedSharding(mesh, PartitionSpec('batch', None))

    # Testing out current state of the model
    test_prompt = "Once upon a time, there was a girl named Amy."
    generate_text(model, params, tokenizer, test_prompt)

    state = train_state.TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=optimizer
    )

    # Perform post training
    print("Start training model")
    state = train_model(state, tiny_stories_dl, num_epochs, sharding)

    # Final text generation
    generate_text(model, state.params, tokenizer, test_prompt)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train Gemma model with custom parameters.')
    parser.add_argument('--maxlen', type=int, default=256, help='Maximum sequence length')
    parser.add_argument('--batch_size', type=int, default=128, help='Batch size')
    parser.add_argument('--datacount', type=int, default=296000, help='Number of data samples to use')
    args = parser.parse_args()

    run_training(maxlen=args.maxlen, batch_size=args.batch_size, datacount=args.datacount)

Dans ce script, les éléments suivants s'appliquent :

  • Un Gemma3Tokenizer convertit les données textuelles en jetons que le modèle peut traiter.
  • La fonction load_and_preprocess_data lit les données d'entraînement à partir d'un fichier, les divise en histoires individuelles et utilise le tokenizer pour convertir le texte en séquences de jetons complétées.
  • La fonction generate_text utilise le modèle, ses paramètres et une invite pour générer du texte.
  • La fonction train_step définit une seule itération d'entraînement qui inclut la passe avant, le calcul de la perte (à l'aide de l'entropie croisée), le calcul du gradient et les mises à jour des paramètres.
  • La fonction train_model parcourt l'ensemble de données pour un nombre d'époques spécifié, ce qui appelle la fonction train_step pour chaque lot.
  • La fonction run_training orchestre l'ensemble du processus de chargement des données, d'initialisation du modèle Gemma 3 (Gemma3_270M) et de l'optimiseur, de chargement des paramètres pré-entraînés, de configuration du partitionnement des données pour le traitement parallèle, d'exécution d'une génération de test, d'exécution de la boucle d'entraînement et d'exécution d'une génération de texte finale pour démontrer l'effet du réglage fin.
  • Le script utilise la bibliothèque argparse pour accepter les arguments de ligne de commande pour les paramètres maxlen, batch_size et datacount.

Maintenant que vous avez exploré le script d'affinage, conteneurisez-le pour l'exécuter sur GKE.

Conteneuriser le script d'affinage

Avant d'exécuter le script d'affinage dans un cluster GKE, vous devez le conteneuriser. Ce tutoriel utilise une image JAX AI comme image de base.

  1. Ouvrez le fichier Dockerfile dans le même répertoire que le fichier Gemma3LLMTrain.py :

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    FROM us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:jax0.7.2-rev1
    RUN apt-get update && apt-get install -y wget && rm -rf /var/lib/apt/lists/*
    
    RUN pip install --upgrade pip
    RUN pip install gemma grain
    
    WORKDIR /app
    
    # Copy your training script into the container
    COPY Gemma3LLMTrain.py .
    

    Ce fichier Dockerfile installe les dépendances nécessaires et copie le fichier Gemma3LLMTrain.py dans le conteneur.

  2. Créez l'image Docker et transférez-la vers un dépôt d'images :

    export REPOSITORY=REPOSITORY_NAME
    export IMAGE_NAME="jax-gemma3-training"
    export IMAGE_TAG="latest"
    export DOCKERFILE_PATH="./Dockerfile"
    export IMAGE_URI="${REGION}-docker.pkg.dev/${PROJECT_ID}/${REPOSITORY}/${IMAGE_NAME}:${IMAGE_TAG}"
    
    docker build -t "${IMAGE_URI}" -f "${DOCKERFILE_PATH}" .
    gcloud auth configure-docker "${REGION}-docker.pkg.dev" -q
    docker push "${IMAGE_URI}"
    

    Remplacez REPOSITORY_NAME par le nom de votre dépôt Artifact Registry.

  3. Ajoutez des liaisons de rôle au compte de service :

    export PROJECT_NUMBER=$(gcloud projects describe $PROJECT_ID --format 'get(projectNumber)')
    gcloud artifacts repositories add-iam-policy-binding ${REPOSITORY} \
        --project=${PROJECT_ID} \
        --location=${REGION} \
        --member="serviceAccount:${PROJECT_NUMBER}-compute@developer.gserviceaccount.com" \
        --role="roles/artifactregistry.reader"
    

Une fois l'image dans le dépôt, vous pouvez déployer le job d'affinage dans un cluster GKE.

Déployer le job d'affinage du LLM

Cette section vous explique comment déployer le job de réglage fin du LLM sur votre cluster GKE.

  1. Ouvrez le fichier manifeste training_singlehost.yaml :

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    apiVersion: batch/v1
    kind: Job
    metadata:
      name: jax-gemma3-train-singlehost
    spec:
      template:
        metadata:
          annotations:
            gke-gcsfuse/volumes: "true"
        spec:
          serviceAccountName: jaxserviceaccout
          containers:
          - name: training-container
            image: ${IMAGE_URI}
            imagePullPolicy: "Always"
            command: ["python", "Gemma3LLMTrain.py", "--maxlen", "256", "--batch_size", "64", "--datacount", "355120"]
            resources:
              limits:
                google.com/tpu: 1
            volumeMounts:
            - name: gcs-fuse-csi-ephemeral
              mountPath: /data
          nodeSelector:
            cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
            cloud.google.com/gke-tpu-topology: 1x1
          restartPolicy: Never
          volumes:
          - name: gcs-fuse-csi-ephemeral
            csi:
              driver: gcsfuse.csi.storage.gke.io
              volumeAttributes:
                bucketName: ${GCS_BUCKET_NAME}
                mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
      backoffLimit: 1
  2. Appliquez le fichier manifeste :

    envsubst < training_singlehost.yaml | kubectl apply -f -
    

GKE crée un Job qui lance un pod sur un nœud TPU Trillium (v6e). Ce pod exécute le script Python d'affinage, qui accède aux données d'affinage à partir du bucket Cloud Storage spécifié, monté au chemin d'accès /data à l'aide de Cloud Storage FUSE. Le script affine ensuite le modèle Gemma.

Surveiller le job d'entraînement

Dans cette section, vous allez surveiller la progression du job d'affinage et ses performances.

Afficher la progression de l'affinage

  1. Répertoriez les pods :

    # Find the Pods
    kubectl get pods
    
  2. Suivez la sortie du journal :

    kubectl logs -f pods/POD_NAME
    

    Remplacez POD_NAME par le nom de votre pod.

    Le résultat ressemble à ce qui suit :

    Global device count: 1
    Batch size: 128, Max length: 256, Data count: 96000
    I1028 00:12:55.925999 1387 google_auth_provider.cc:181] Running on GCE, using service account ...
    Generating response for: Once upon a time, there was a girl named Amy.
    Response:
    Amy lived in a small house. The house was in a big field. Amy liked to play in the big field. She
    Start training model
    Loss after batch 0: 10.25
    Loss after batch 10: 4.3125
    .
    .
    .
    Loss after batch 740: 1.41406
    Completed training model. Total time for training 294.6791355609894 seconds
    Generating response for: Once upon a time, there was a girl named Amy.
    Response:
    She loved to play with her toys. One day, Amy's mom told her that she had to go to the store to
    
  3. Analysez le résultat :

    • La ligne Global device count: 1 indique les cœurs de TPU utilisés.
    • Le modèle génère un texte raisonnable avant cette exécution de l'affinage, car il se charge à partir d'un point de contrôle pré-entraîné.
    • La sortie générée après l'affinage ressemble davantage au début d'une nouvelle, ce qui indique que le modèle apprend à partir du nouvel ensemble de données.
    • L'affinage sur l'ensemble de données complet devrait produire des résultats encore plus précis.

Observer les métriques

Consultez les performances du job d'affinage en vérifiant les métriques TPU et CPU. Pour afficher les métriques d'observabilité de votre cluster, suivez les étapes décrites dans Afficher les métriques d'observabilité des clusters et des charges de travail.

Autres configurations d'affinage

Cette section décrit d'autres configurations pour votre charge de travail d'affinage.

Sélection du modèle

Ce tutoriel utilisait le modèle Gemma3_270M, qui est un petit modèle qui s'intègre dans un pool de nœuds TPU Trillium (v6e) à hôte unique. Pour les modèles plus volumineux qui nécessitent plus de mémoire et de calcul pour l'affinage, vous pouvez utiliser des configurations de pool de nœuds multihôtes ou multislices.

Pour obtenir la liste complète des modèles disponibles, consultez la documentation Gemma.

Configurations de pool de nœuds

Ce tutoriel utilisait un pool de nœuds à hôte unique. Vous pouvez également créer des pools de nœuds de tranche TPU multi-hôtes ou des pools de nœuds multislices, selon vos besoins.

Les onglets suivants montrent comment créer des pools de nœuds multi-hôtes et multislices :

Multi-hôtes

  1. Dans Cloud Shell, exécutez la commande suivante :

    gcloud container node-pools create jax-tpu-multihost1 \
        --cluster=${CLUSTER_NAME} \
        --machine-type=ct6e-standard-4t \
        --num-nodes=2 \
        --tpu-topology=2x4 \
        --location=${REGION} \
        --node-locations=${ZONE}
    

    GKE crée un pool de nœuds TPU Trillium avec une topologie 2x4 et deux nœuds.

  2. Ouvrez la définition du job training_multihost_jobset.yaml :

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: jax-gemma3-train-multihost
    spec:
      replicatedJobs:
        - name: trainers
          replicas: 1
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 1
              template:
                metadata:
                  annotations:
                    gke-gcsfuse/volumes: "true"
                spec:
                  serviceAccountName: jaxserviceaccout
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                    cloud.google.com/gke-tpu-topology: 2x4
                    cloud.google.com/gke-nodepool: jax-tpu-multihost1
                  containers:
                  - name: training-container
                    image: ${IMAGE_URI} 
                    imagePullPolicy: "Always"
                    ports:
                      - containerPort: 8471
                    command: ["python", "Gemma3LLMTrain.py", "--maxlen", "256", "--batch_size", "64", "--datacount", "5120"]
                    resources:
                      limits:
                        google.com/tpu: 4
                    volumeMounts:
                    - name: gcs-fuse-csi-ephemeral
                      mountPath: /data
                  volumes:
                    - name: gcs-fuse-csi-ephemeral
                      csi:
                        driver: gcsfuse.csi.storage.gke.io
                        volumeAttributes:
                          bucketName: ${GCS_BUCKET_NAME}
                          mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
    
  3. Déployez le job d'affinage :

    envsubst < training_multihost_jobset.yaml | kubectl apply -f -
    

Multitranches

  1. Dans Cloud Shell, exécutez la commande suivante :

    gcloud container node-pools create jax-tpu-multihost1 \
      --cluster=${CLUSTER_NAME} \
      --machine-type=ct6e-standard-4t \
      --num-nodes=2 \
      --tpu-topology=2x4 \
      --location=${REGION} \
      --node-locations=${ZONE}
    
    gcloud container node-pools create jax-tpu-multihost2 \
      --cluster=${CLUSTER_NAME} \
      --machine-type=ct6e-standard-4t \
      --num-nodes=2 \
      --tpu-topology=2x4 \
      --location=${REGION} \
      --node-locations=${ZONE}
    

    GKE crée deux pools de nœuds TPU Trillium. Chaque pool de nœuds possède une topologie 2x4 et deux nœuds.

  2. Ouvrez la définition du job training_multislice_jobset.yaml :

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: jax-gemma3-train-multislice
    spec:
      replicatedJobs:
        - name: trainers
          replicas: 2
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 1
              template:
                metadata:
                  annotations:
                    gke-gcsfuse/volumes: "true"
                spec:
                  serviceAccountName: jaxserviceaccout
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                    cloud.google.com/gke-tpu-topology: 2x4
                  containers:
                  - name: training-container
                    image: ${IMAGE_URI}
                    imagePullPolicy: "Always"
                    ports:
                      - containerPort: 8471
                    command: ["python", "Gemma3LLMTrain.py", "--maxlen", "256", "--batch_size", "64", "--datacount", "5120"]
                    resources:
                      limits:
                        google.com/tpu: 4
                    volumeMounts:
                    - name: gcs-fuse-csi-ephemeral
                      mountPath: /data
                  volumes:
                    - name: gcs-fuse-csi-ephemeral
                      csi:
                        driver: gcsfuse.csi.storage.gke.io
                        volumeAttributes:
                          bucketName: ${GCS_BUCKET_NAME}
                          mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:100"
    
  3. Déployez le job d'affinage :

    envsubst < training_multislice_jobset.yaml | kubectl apply -f -
    

Analyse et optimisation des performances

Pour analyser et optimiser les performances de votre affinement du machine learning, vous pouvez utiliser XProf. XProf est une suite d'outils qui profile et inspecte les charges de travail de ML créées avec JAX, TensorFlow ou PyTorch/XLA. En affichant les traces d'exécution, l'utilisation de la mémoire et d'autres données, XProf vous permet d'affiner vos modèles et votre configuration d'entraînement pour une meilleure efficacité et un entraînement plus rapide.

Pour analyser les performances de votre charge de travail d'affinage à l'aide de XProf, suivez les étapes décrites dans cette section :

  • Installez le package xprof. Modifiez votre script d'entraînement pour démarrer le serveur XProf.
  • Modifiez le fichier manifeste de votre job Kubernetes pour inclure un montage de volume pour les journaux XProf.
  • Accordez au compte de service les autorisations nécessaires pour écrire les journaux XProf dans un bucket Cloud Storage.
  • Exécutez XProf dans votre pod et configurez le transfert de port pour accéder au tableau de bord XProf.

Installer le package XProf

  1. Accédez au répertoire qui contient les exemples XProf :

      cd ai-ml/llm-training-jax-tpu-gemma3/xprof-enabled
    
  2. Créez l'image Docker et transférez-la vers un dépôt d'images :

    export REPOSITORY=REPOSITORY_NAME
    export IMAGE_NAME="jax-gemma3-training-xp"
    export IMAGE_TAG="latest"
    export DOCKERFILE_PATH="./Dockerfile"
    export IMAGE_URI="${REGION}-docker.pkg.dev/${PROJECT_ID}/${REPOSITORY}/${IMAGE_NAME}:${IMAGE_TAG}"
    
    docker build -t "${IMAGE_URI}" -f "${DOCKERFILE_PATH}" .
    gcloud auth configure-docker "${REGION}-docker.pkg.dev" -q
    docker push "${IMAGE_URI}"
    

    Remplacez REPOSITORY_NAME par le nom de votre dépôt Artifact Registry.

  3. Exécutez le script Dockerfile :

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    FROM us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:jax0.7.2-rev1
    RUN apt-get update && apt-get install -y wget && rm -rf /var/lib/apt/lists/*
    
    RUN pip install --upgrade pip
    RUN pip install gemma grain equinox
    RUN pip install xprof
    
    WORKDIR /app
    
    # Copy your training script into the container
    COPY Gemma3LLMTrain.py .
    

    Ce fichier Dockerfile installe les dépendances XProf.

Copiez votre script d'affinage dans le conteneur.

Dans cette section, créez et appliquez un fichier manifeste de job Kubernetes qui inclut les montages de volume nécessaires pour les journaux XProf.

  1. Ouvrez la définition du job training_singlehost.yaml :

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    apiVersion: batch/v1
    kind: Job
    metadata:
      name: jax-gemma3-train-singlehost
    spec:
      template:
        metadata:
          annotations:
            gke-gcsfuse/volumes: "true"
        spec:
          serviceAccountName: jaxserviceaccout
          containers:
          - name: training-container
            image: ${IMAGE_URI}
            imagePullPolicy: "Always"
            command: ["python", "Gemma3LLMTrain.py", "--maxlen", "256", "--batch_size", "64", "--datacount", "851200"]
            resources:
              limits:
                google.com/tpu: 1
            volumeMounts:
            - name: gcs-fuse-csi-ephemeral
              mountPath: /data
            - name: gcs-fuse-csi-ephemeral2
              mountPath: /xprof
          nodeSelector:
            cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
            cloud.google.com/gke-tpu-topology: 1x1
          restartPolicy: Never
          volumes:
          - name: gcs-fuse-csi-ephemeral
            csi:
              driver: gcsfuse.csi.storage.gke.io
              volumeAttributes:
                bucketName: ${GCS_BUCKET_NAME}
                mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
          - name: gcs-fuse-csi-ephemeral2
            csi:
              driver: gcsfuse.csi.storage.gke.io
              volumeAttributes:
                bucketName: ${XPROF_GCS_BUCKET_NAME}
                mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
      backoffLimit: 1
  2. Appliquez le fichier manifeste :

    envsubst < training_singlehost.yaml | kubectl apply -f -
    

Accorder au compte de service les autorisations nécessaires pour écrire les journaux XProf

  1. Pour permettre au compte de service d'écrire et de lire, ajoutez le rôle "roles/storage.objectUser" :

    export GSA_NAME="GSA_NAME" # Same as used in initial setup
    
    # Automatically get the current project ID
    export PROJECT_ID=$(gcloud config get-value project)
    
    # Cloud Storage Bucket details
    export XPROF_GCS_BUCKET_NAME="XPROF_GCS_BUCKET_NAME"
    
    # Derived Variables
    export GSA_EMAIL="${GSA_NAME}@${PROJECT_ID}.iam.gserviceaccount.com"
    
    gcloud storage buckets add-iam-policy-binding "gs://${XPROF_GCS_BUCKET_NAME}" \
        --member="serviceAccount:${GSA_EMAIL}" \
        --role="roles/storage.objectUser" \
        --project="${PROJECT_ID}"
    

    Remplacez les éléments suivants :

    • GSA_NAME : nom du compte de service Google auquel attribuer le rôle.
    • XPROF_GCS_BUCKET_NAME : nom du bucket auquel attribuer le rôle.
  2. Exécutez XProf dans votre pod :

    kubectl exec POD_NAME -c training-container -it -- bash # exec into the container
    xprof --port 9001 --logdir /xprof # start xprof
    

    Remplacez POD_NAME par le nom de votre pod.

Accéder au tableau de bord XProf

  1. Configurez un transfert de port vers le serveur XProf dans le pod :

    kubectl port-forward POD_NAME 9001:9001
    
  2. Dans la barre d'adresse de votre navigateur, saisissez les informations suivantes :

    http://localhost:9001/
    

    La visionneuse de trace XProf s'ouvre.

  3. Dans la fenêtre TensorBoard, cliquez sur Capturer le profil.

  4. Dans le champ URL de service du profil ou nom du TPU, saisissez localhost:9002.

  5. Pour capturer plus de détails, dans Niveau de trace de l'hôte (TraceMe), sélectionnez verbose et activez la journalisation des traces Python.

  6. Pour afficher le tableau de bord, cliquez sur Capture.

    TensorBoard capture le profil et vous permet d'analyser les performances du script d'entraînement. Le graphique affiche la chronologie d'exécution pour les profils de performances TPU et CPU :

Exemple du traceur XProf affichant un graphique de matrice de performances

Pour découvrir d'autres options de profilage permettant d'analyser les performances de votre charge de travail d'entraînement, consultez la documentation JAX sur le profilage du calcul.

Finetuning dans les environnements de production

Ce tutoriel vous a montré comment tester l'entraînement basé sur JAX dans un environnement distribué. Pour un affinement optimisé des LLM en production, utilisez la bibliothèque MaxText. Si vous êtes intéressé par les modèles de diffusion, utilisez les implémentations Maxdiffusion.

Pour les charges de travail d'entraînement ou de réglage fin de longue durée en production, configurez la gestion des points de contrôle des charges de travail afin de minimiser la perte de progression en cas d'échec. Pour en savoir plus sur la configuration du checkpointing à plusieurs niveaux, consultez Entraîner des modèles de machine learning à grande échelle sur GKE avec le checkpointing à plusieurs niveaux.

Effectuer un nettoyage

Pour éviter que les ressources utilisées lors de ce tutoriel soient facturées sur votre compte Google Cloud, supprimez le projet contenant les ressources, ou conservez le projet et supprimez les ressources individuelles.

Supprimer les ressources individuelles

Pour éviter que les ressources utilisées dans ce tutoriel ne soient facturées sur votre compte Google Cloud , supprimez le projet contenant les ressources, ou conservez le projet et supprimez les ressources individuelles en exécutant les commandes suivantes :

  1. Supprimez les ressources que vous avez créées dans ce tutoriel :

    gcloud container clusters delete ${CLUSTER_NAME} --location=${REGION}
    gcloud storage rm --recursive gs://${GCS_BUCKET_NAME}
    gcloud artifacts docker images delete ${IMAGE_URI} --delete-tags
    
  2. Si vous n'avez pas besoin des données générées par XProf, supprimez le bucket Cloud Storage utilisé par XProf :

    gcloud storage rm --recursive gs://${XPROF_GCS_BUCKET_NAME}
    

Étapes suivantes