GKE에서 JAX를 사용하여 TPU로 LLM 미세 조정

이 튜토리얼에서는 JAX를 사용하여 Google Kubernetes Engine (GKE)에서 텐서 처리 장치 (TPU)를 사용하여 대규모 언어 모델 (LLM)을 미세 조정하는 방법을 보여줍니다. 미세 조정을 사용하면 Gemma 3와 같은 파운데이션 모델을 특정 도메인 또는 태스크에 맞게 조정할 수 있습니다. 이 프로세스는 자체 전문 데이터 세트로 모델의 파라미터를 업데이트하여 모델의 정밀도와 정확도를 개선합니다.

이 가이드는 AI/ML 워크로드를 파인 튜닝할 때 관리형 Kubernetes의 세밀한 제어, 맞춤설정, 확장성, 복원력, 이동성, 비용 효율성이 필요한 경우 좋은 출발점이 될 수 있습니다.

배경

Jax와 함께 GKE에서 TPU를 사용하여 LLM을 미세 조정하면 관리형 Kubernetes의 모든 이점을 갖춘 강력한 프로덕션에 즉시 사용 가능한 미세 조정 솔루션을 빌드할 수 있습니다.

Gemma

Gemma는 오픈 라이선스로 출시된 공개적으로 사용 가능한 가벼운 생성형 AI/ML 멀티모달 모델의 집합입니다. 이러한 AI 모델은 애플리케이션, 하드웨어, 휴대기기 또는 호스팅된 서비스에서 실행할 수 있습니다. Gemma 3는 멀티모달 기능을 도입하여 시각-언어 입력과 텍스트 출력을 지원합니다. 최대 128,000개 토큰의 컨텍스트 윈도우를 처리하고 140개 이상의 언어를 지원합니다. 또한 Gemma 3는 구조화된 출력과 함수 호출을 포함한 수학, 추론, 채팅 기능이 개선되었습니다.

텍스트 생성에 Gemma 모델을 사용하거나, 특수한 태스크를 위해 이러한 모델을 조정할 수도 있습니다.

자세한 내용은 Gemma 문서를 참고하세요.

TPU

TPU는 Google이 TensorFlow, PyTorch, JAX와 같은 프레임워크를 사용하여 빌드된 머신러닝 및 AI 모델을 가속화하기 위해 맞춤 개발한 ASIC (application-specific integrated circuits)입니다.

GKE에서 TPU를 사용하기 전에 다음 학습 과정을 완료하는 것이 좋습니다.

  1. Cloud TPU 시스템 아키텍처를 사용하는 현재 TPU 버전 가용성 알아보기
  2. GKE의 TPU에 대해 알아봅니다.

JAX

JAX는 TPU 및 GPU와 함께 사용하도록 설계된 고성능 머신러닝 프레임워크입니다. JAX는 머신러닝 모델을 빌드하고 학습하기 위한 API를 제공합니다.

자세한 내용은 JAX 저장소를 참고하세요.

목표

이 튜토리얼은 다음 과정을 다룹니다.

  1. 모델 특성에 따라 권장 TPU 토폴로지를 사용하여 GKE Autopilot 또는 Standard 클러스터를 만듭니다. 이 튜토리얼에서는 단일 호스트 노드 풀에서 미세 조정 작업을 수행합니다.
  2. Cloud Storage 버킷에 데이터를 추가하고 Cloud Storage FUSE를 통해 컨테이너에 마운트합니다.
  3. GKE에 LLM 미세 조정 작업을 배포합니다.
  4. 미세 조정 작업을 모니터링하고 로그를 확인합니다.

시작하기 전에

  • Google Cloud 계정에 로그인합니다. Google Cloud를 처음 사용하는 경우 계정을 만들고 Google 제품의 실제 성능을 평가해 보세요. 신규 고객에게는 워크로드를 실행, 테스트, 배포하는 데 사용할 수 있는 $300의 무료 크레딧이 제공됩니다.
  • In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Roles required to select or create a project

    • Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
    • Create a project: To create a project, you need the Project Creator role (roles/resourcemanager.projectCreator), which contains the resourcemanager.projects.create permission. Learn how to grant roles.

    Go to project selector

  • Verify that billing is enabled for your Google Cloud project.

  • Enable the required API.

    Roles required to enable APIs

    To enable APIs, you need the Service Usage Admin IAM role (roles/serviceusage.serviceUsageAdmin), which contains the serviceusage.services.enable permission. Learn how to grant roles.

    Enable the API

  • In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Roles required to select or create a project

    • Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
    • Create a project: To create a project, you need the Project Creator role (roles/resourcemanager.projectCreator), which contains the resourcemanager.projects.create permission. Learn how to grant roles.

    Go to project selector

  • Verify that billing is enabled for your Google Cloud project.

  • Enable the required API.

    Roles required to enable APIs

    To enable APIs, you need the Service Usage Admin IAM role (roles/serviceusage.serviceUsageAdmin), which contains the serviceusage.services.enable permission. Learn how to grant roles.

    Enable the API

  • 프로젝트에 다음 역할이 있는지 확인합니다. roles/container.admin,roles/iam.serviceAccountAdmin,roles/storage.admin

    역할 확인

    1. Google Cloud 콘솔에서 IAM 페이지로 이동합니다.

      IAM으로 이동
    2. 프로젝트를 선택합니다.
    3. 주 구성원 열에서 나 또는 내가 속한 그룹을 식별하는 모든 행을 찾습니다. 내가 속한 그룹을 알아보려면 관리자에게 문의하세요.

    4. 나를 지정하거나 포함하는 모든 행의 역할 열을 확인하여 역할 목록에 필요한 역할이 포함되어 있는지 확인합니다.

    역할 부여

    1. Google Cloud 콘솔에서 IAM 페이지로 이동합니다.

      IAM으로 이동
    2. 프로젝트를 선택합니다.
    3. 액세스 권한 부여를 클릭합니다.
    4. 새 주 구성원 필드에 사용자 식별자를 입력합니다. 일반적으로 Google 계정의 이메일 주소입니다.

    5. 역할 선택을 클릭한 후 역할을 검색합니다.
    6. 역할을 추가로 부여하려면 다른 역할 추가를 클릭하고 각 역할을 추가합니다.
    7. 저장을 클릭합니다.
  • TPU Trillium (v6e) 칩 16개에 대해 충분한 할당량이 있는지 확인합니다. 이 튜토리얼에서는 16개의 칩과 온디맨드 인스턴스가 필요한 노드 풀 구성을 사용합니다.
  • Docker 저장소가 있는지 확인합니다. 저장소가 없으면 Artifact Registry에서 표준 저장소를 만듭니다.

환경 준비

이 튜토리얼에서는 Cloud Shell을 사용하여 Google Cloud에서 호스팅되는 리소스를 관리합니다. Cloud Shell에는 kubectlGoogle Cloud CLI 등 이 튜토리얼에 필요한 소프트웨어가 사전 설치되어 있습니다.

Cloud Shell로 환경을 설정하려면 다음 단계를 따르세요.

  1. Google Cloud 콘솔에서 Cloud Shell 세션을 시작하고 Cloud Shell 활성화 아이콘 Cloud Shell 활성화를 클릭합니다. 그러면 Google Cloud 콘솔 하단 창에서 세션이 실행됩니다.

  2. 기본 환경 변수를 설정합니다.

    gcloud config set project PROJECT_ID
    gcloud config set billing/quota_project PROJECT_ID
    export PROJECT_ID=$(gcloud config get project)
    export CLUSTER_NAME=CLUSTER_NAME
    export REGION=CONTROL_PLANE_LOCATION
    export ZONE=ZONE
    export GCS_BUCKET_NAME=BUCKET_NAME
    

    다음 값을 바꿉니다.

    • PROJECT_ID: Google Cloud 프로젝트 ID입니다.
    • CLUSTER_NAME: GKE 클러스터의 이름입니다.
    • CONTROL_PLANE_LOCATION: GKE 클러스터와 TPU 노드가 있는 Compute Engine 리전입니다. 리전에는 TPU Trillium (v6e) 머신 유형을 사용할 수 있는 영역이 포함되어야 합니다.
    • ZONE: 선택한 CONTROL_PLANE_LOCATION 리전 내에서 TPU Trillium (v6e) 머신 유형을 사용할 수 있는 영역입니다. TPU Trillium (v6e) TPU를 사용할 수 있는 영역을 나열하려면 다음 명령어를 실행하세요.

        gcloud compute accelerator-types list --filter="name~ct6e" --format="value(zone)"
      
    • BUCKET_NAME: 학습 데이터가 포함된 Cloud Storage 버킷의 이름입니다.

  3. 샘플 저장소를 클론합니다.

    git clone https://github.com/GoogleCloudPlatform/kubernetes-engine-samples.git
    cd kubernetes-engine-samples
    
  4. 작업 디렉터리로 이동합니다.

    cd ai-ml/llm-training-jax-tpu-gemma3
    

Google Cloud 리소스 만들기 및 구성

이 섹션에서는 Google Cloud 리소스를 만들고 구성합니다.

GKE 클러스터 만들기

GKE Autopilot 또는 Standard 클러스터의 TPU에서 LLM을 미세 조정할 수 있습니다. 완전 관리형 Kubernetes 환경을 위해서는 Autopilot 클러스터를 사용하는 것이 좋습니다. 워크로드에 가장 적합한 GKE 작업 모드를 선택하려면 GKE 작업 모드 선택을 참고하세요.

Autopilot

GKE용 워크로드 아이덴티티 제휴를 사용하고 Cloud Storage FUSE가 사용 설정된 GKE Autopilot 클러스터를 만듭니다.

gcloud container clusters create-auto ${CLUSTER_NAME} \
    --location=${REGION}

클러스터 만들기는 몇 분 정도 걸릴 수 있습니다.

표준

  1. GKE용 워크로드 아이덴티티 제휴를 사용하고 Cloud Storage FUSE가 사용 설정된 리전의 GKE Standard 클러스터를 만듭니다.

    gcloud container clusters create ${CLUSTER_NAME} \
        --enable-ip-alias \
        --addons GcsFuseCsiDriver \
        --machine-type=n2-standard-4 \
        --num-nodes=2 \
        --workload-pool=${PROJECT_ID}.svc.id.goog \
        --location=${REGION}
    

    클러스터 만들기는 몇 분 정도 걸릴 수 있습니다.

  2. 단일 호스트 노드 풀을 만듭니다.

    gcloud container node-pools create jax-tpu-nodepool \
        --cluster=${CLUSTER_NAME} \
        --machine-type=ct6e-standard-1t \
        --num-nodes=1 \
        --location=${REGION} \
        --node-locations=${ZONE} \
        --workload-metadata=GKE_METADATA
    

GKE는 1x1 토폴로지와 1개의 노드가 있는 TPU Trillium 노드 풀을 만듭니다. --workload-metadata=GKE_METADATA 플래그는 GKE 메타데이터 서버를 사용하도록 노드 풀을 구성합니다.

JobSet 설치

  1. 클러스터와 통신하도록 kubectl을 구성합니다.

    gcloud container clusters get-credentials ${CLUSTER_NAME} --location=${REGION}
    
  2. 최신 출시 버전의 JobSet을 설치합니다.

    kubectl apply --server-side -f https://github.com/kubernetes-sigs/jobset/releases/download/JOBSET_VERSION/manifests.yaml
    

    JOBSET_VERSION최신 출시 버전의 JobSet으로 바꿉니다. 예를 들면 v0.11.0입니다.

  3. JobSet 설치를 확인합니다.

    kubectl get pods -n jobset-system
    

    출력은 다음과 비슷합니다.

    NAME                                         READY   STATUS    RESTARTS   AGE
    jobset-controller-manager-6c56668494-l4dhc   1/1     Running   0          4m45s
    

    JobSet이 리소스를 기다리는 경우 노드를 더 추가해야 할 수 있습니다.

Cloud Storage FUSE 구성

LLM을 미세 조정하려면 학습 데이터를 제공해야 합니다. 이 튜토리얼에서는 Hugging Face의 TinyStories 데이터 세트를 사용합니다. 이 데이터 세트에는 제한된 어휘를 사용하는 GPT-3.5 및 GPT-4에 의해 합성적으로 생성된 단편 소설이 포함되어 있습니다.

이 섹션에서는 Cloud Storage 버킷에서 데이터를 읽도록 Cloud Storage FUSE를 구성하는 단계를 설명합니다.

  1. 데이터 세트를 다운로드합니다.

    wget https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories-train.txt?download=true -O TinyStories-train.txt
    
  2. 새 Cloud Storage 버킷에 데이터를 업로드합니다.

    gcloud storage buckets create gs://${GCS_BUCKET_NAME} \
        --location=${REGION} \
        --enable-hierarchical-namespace \
        --uniform-bucket-level-access
    gcloud storage cp TinyStories-train.txt gs://${GCS_BUCKET_NAME}
    
  3. 워크로드가 Cloud Storage FUSE를 통해 데이터를 읽을 수 있도록 하려면 Kubernetes 서비스 계정 (KSA)을 만들고 필요한 권한을 추가합니다. permissionsetup.sh 스크립트를 실행합니다.

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    #!/bin/bash
    
    # --- Configuration Variables ---
    # Kubernetes Service Account details
    export KSA_NAME="jaxserviceaccout"
    export NAMESPACE="default"
    
    # Google Cloud IAM Service Account details
    export GSA_NAME="<GSA_NAME>"
    # Automatically get the current project ID
    export PROJECT_ID=$(gcloud config get-value project)
    export  GSA_DESCRIPTION="GKE Service Account to read GCS bucket for ${KSA_NAME}"
    
    # GCS Bucket details
    export GCS_BUCKET_NAME="<GCS_BUCKET_NAME>" # <--- IMPORTANT: Update this to your bucket name
    
    # Derived Variables
    export GSA_EMAIL="${GSA_NAME}@${PROJECT_ID}.iam.gserviceaccount.com"
    export WI_MEMBER="serviceAccount:${PROJECT_ID}.svc.id.goog[${NAMESPACE}/${KSA_NAME}]"
    
    # --- Check if PROJECT_ID is set ---
    if [ -z "${PROJECT_ID}" ]; then
      echo "Error: PROJECT_ID is not set. Please set it using 'gcloud config set project YOUR_PROJECT_ID'"
      exit 1
    fi
    
    echo "--- Configuration ---"
    echo "KSA_NAME:      ${KSA_NAME}"
    echo "NAMESPACE:     ${NAMESPACE}"
    echo "GSA_NAME:      ${GSA_NAME}"
    echo "PROJECT_ID:    ${PROJECT_ID}"
    echo "GSA_EMAIL:     ${GSA_EMAIL}"
    echo "GCS_BUCKET_NAME:   ${GCS_BUCKET_NAME}"
    echo "WI_MEMBER:     ${WI_MEMBER}"
    echo "--------------------"
    read -p "Press enter to continue..."
    
    # --- Command Execution ---
    
    echo "[1/5] Creating Google Cloud IAM Service Account (GSA): ${GSA_NAME}"
    gcloud iam service-accounts create "${GSA_NAME}" \
        --project="${PROJECT_ID}" \
        --description="${GSA_DESCRIPTION}" \
        --display-name="${GSA_NAME}"
    
    echo "[2/5] Granting GSA '${GSA_EMAIL}' read access (roles/storage.objectViewer) to bucket 'gs://${GCS_BUCKET_NAME}'"
    gcloud storage buckets add-iam-policy-binding "gs://${GCS_BUCKET_NAME}" \
        --member="serviceAccount:${GSA_EMAIL}" \
        --role="roles/storage.objectViewer" \
        --project="${PROJECT_ID}"
    
    echo "[3/5] Creating Kubernetes Service Account (KSA): ${KSA_NAME} in namespace ${NAMESPACE}"
    kubectl create serviceaccount "${KSA_NAME}" --namespace "${NAMESPACE}"
    
    echo "[4/5] Allowing KSA to impersonate GSA (Workload Identity Binding): ${GSA_EMAIL}"
    gcloud iam service-accounts add-iam-policy-binding "${GSA_EMAIL}" \
        --role roles/iam.workloadIdentityUser \
        --member "${WI_MEMBER}" \
        --project="${PROJECT_ID}"
    
    echo "[5/5] Annotating KSA '${KSA_NAME}' to link with GSA '${GSA_EMAIL}'"
    kubectl annotate serviceaccount "${KSA_NAME}" \
        --namespace "${NAMESPACE}" \
        iam.gke.io/gcp-service-account="${GSA_EMAIL}"
    
    echo "--- Setup Complete ---"
    echo "Pods in namespace '${NAMESPACE}' using serviceAccount '${KSA_NAME}' can now authenticate as '${GSA_EMAIL}' and have read access to 'gs://${GCS_BUCKET_NAME}'."
    

    이 스크립트를 실행하면Google Cloud 프로젝트 및 GKE 클러스터에 다음 리소스가 구성됩니다.

    • gcs-fuse-sa이라는 새 IAM 서비스 계정이 프로젝트에 생성됩니다.
    • 생성된 Google Cloud 서비스 계정 (GSA) (gcs-fuse-sa)에는 ${GCS_BUCKET_NAME}로 지정된 Cloud Storage 버킷에 대한 roles/storage.objectViewer 역할이 부여됩니다. 이 권한을 사용하면 GSA가 버킷에서 객체를 읽을 수 있습니다.
    • GKE 클러스터 내 default 네임스페이스에 jaxserviceaccount라는 새 KSA가 생성됩니다.
    • GSA의 IAM 정책이 업데이트되어 KSA에 roles/iam.workloadIdentityUser 역할이 부여됩니다. 이 권한을 사용하면 KSA가 GSA를 가장할 수 있습니다.
    • KSA에 주석을 추가하여 GSA에 연결합니다. 이 주석은 워크로드 아이덴티티를 사용하여 KSA가 가장해야 하는 GSA를 GKE에 알려줍니다.

      이제 jaxserviceaccount 서비스 계정을 사용하는 GKE 클러스터의 default 네임스페이스에서 실행되는 모든 포드가 gcs-fuse-sa GSA로 인증할 수 있습니다. 이러한 포드는 gs://${GCS_BUCKET_NAME} 버킷에 저장된 객체에 대한 읽기 액세스 권한을 갖게 되며, 이는 미세 조정 작업이 Cloud Storage FUSE를 사용하여 데이터 세트에 액세스하는 데 필수적입니다.

미세 조정 스크립트 만들기

이 섹션에서는 Gemma 3 모델에서 파인 튜닝 작업을 실행하는 학습 스크립트를 살펴봅니다. 이 스크립트는 Gemma3Tokenizer을 사용합니다.

다음 Gemma3LLMTrain.py 미세 조정 스크립트를 검토하세요.

# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import grain.python as pygrain
import jax
import jax.numpy as jnp
import optax
import pandas as pd
import time
import argparse

from dataclasses import dataclass
from functools import partial
from gemma import gm
from flax.training import train_state
from jax.sharding import Mesh, PartitionSpec, NamedSharding

jax.distributed.initialize()
print("Global device count:", jax.device_count())
print("jax version:", jax.__version__)

tokenizer = gm.text.Gemma3Tokenizer()
num_epochs = 1
learning_rate = 2e-5

@dataclass
class TextDataset:
    data: list
    maxlen: int
    tokenizer: gm.text.Gemma3Tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx: int):
        encoding = self.tokenizer.encode(self.data[idx])[:self.maxlen]  # Tokenize and truncate
        return encoding + [0] * (self.maxlen - len(encoding))  # Pad to maxlen

def load_and_preprocess_data(file_path, batch_size, maxlen, datacount, tokenizer):

    with open(file_path, 'r') as f:
      text = f.read()

    stories = text.split('<|endoftext|>')
    stories = [story for story in stories if story.strip()][:datacount]
    df = pd.DataFrame({'text': stories})
    data = df['text'].dropna().tolist()
    dataset = TextDataset(data, maxlen, tokenizer)

    sampler = pygrain.IndexSampler(
        len(dataset),
        shuffle=False,
        seed=42,
        shard_options=pygrain.NoSharding(),
        num_epochs=num_epochs,
    )

    dataloader = pygrain.DataLoader(
        data_source=dataset,
        sampler=sampler,
        operations=[pygrain.Batch(batch_size=batch_size, drop_remainder=True)],
    )

    return dataloader

def generate_text(model, params, tokenizer, prompt):
    sampler = gm.text.Sampler(
        model=model,
        params=params,
        tokenizer=tokenizer,
    )
    print("Generating response for: " + prompt)
    out = sampler.sample(prompt, max_new_tokens=32)
    print("Reponse: \n" + out + "\n")
    return out

prep_target_batch = jax.vmap(lambda tokens: jnp.concatenate((tokens[1:], jnp.array([0]))))

@partial(jax.jit, donate_argnums=(0,))
def train_step(state, batch):
    """Performs one supervised fine-tuning step."""

    def loss_fn(params):
        # Run the forward pass. The model returns logits.
        logits = state.apply_fn({'params': params}, batch[0]).logits

        # Calculate the cross-entropy loss.
        loss = optax.softmax_cross_entropy_with_integer_labels(
            logits=logits, labels=batch[1]
        ).mean()

        return loss

    # Compute gradients
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)

    # Update the model state
    state = state.apply_gradients(grads=grads)

    metrics = {'loss': loss}
    return state, metrics

def train_model(state, text_dl, num_epochs, sharding):
    batchCount = 0
    start_time = time.time()
    for epoch in range(num_epochs):
        start_time = time.time()
        for batch in text_dl:
            if len(batch) % len(jax.devices()) != 0:
              continue  # skip the remaining elements
            input_batch = jnp.array(jnp.array(batch).T)
            target_batch = prep_target_batch(input_batch)
            state, metrics = train_step(state, jax.device_put((input_batch, target_batch), sharding))

            if batchCount % 10 == 0:
                print(f"Loss after batch {batchCount}: {metrics['loss']}")
            batchCount += 1

    end_time = time.time()
    print(f"Completed training model. Total time for training {end_time - start_time} seconds \n")
    return state

def run_training(maxlen, batch_size, datacount):
    print(f"Batch size: {batch_size}, Max length: {maxlen}, Data count: {datacount}")
    #Load the training data
    tiny_stories_dl = load_and_preprocess_data('/data/TinyStories-train.txt', batch_size, maxlen, datacount, tokenizer)
    # Get the Gemma3 model
    model = gm.nn.Gemma3_270M()
    # Load the pretrained parameters
    params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_270M_PT)
    # Create an optimizer
    optimizer = optax.adamw(learning_rate=learning_rate)
    # Define sharding for data parallel training
    mesh = Mesh(jax.devices(), ('batch',))
    sharding = NamedSharding(mesh, PartitionSpec('batch', None))

    # Testing out current state of the model
    test_prompt = "Once upon a time, there was a girl named Amy."
    generate_text(model, params, tokenizer, test_prompt)

    state = train_state.TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=optimizer
    )

    # Perform post training
    print("Start training model")
    state = train_model(state, tiny_stories_dl, num_epochs, sharding)

    # Final text generation
    generate_text(model, state.params, tokenizer, test_prompt)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train Gemma model with custom parameters.')
    parser.add_argument('--maxlen', type=int, default=256, help='Maximum sequence length')
    parser.add_argument('--batch_size', type=int, default=128, help='Batch size')
    parser.add_argument('--datacount', type=int, default=296000, help='Number of data samples to use')
    args = parser.parse_args()

    run_training(maxlen=args.maxlen, batch_size=args.batch_size, datacount=args.datacount)

이 스크립트에서는 다음이 적용됩니다.

  • Gemma3Tokenizer은 텍스트 데이터를 모델이 처리할 수 있는 토큰으로 변환합니다.
  • load_and_preprocess_data 함수는 파일에서 학습 데이터를 읽고, 개별 스토리로 분할하고, 토큰화 도구를 사용하여 텍스트를 패딩된 토큰 시퀀스로 변환합니다.
  • generate_text 함수는 모델, 모델의 매개변수, 프롬프트를 사용하여 텍스트를 생성합니다.
  • train_step 함수는 순방향 패스, 손실 계산 (교차 엔트로피 사용), 그라데이션 계산, 파라미터 업데이트를 포함하는 단일 학습 반복을 정의합니다.
  • train_model 함수는 지정된 수의 에포크 동안 데이터 세트를 반복하여 각 배치에 대해 train_step 함수를 호출합니다.
  • run_training 함수는 데이터를 로드하고, Gemma 3 모델 (Gemma3_270M)과 옵티마이저를 초기화하고, 사전 학습된 매개변수를 로드하고, 병렬 처리를 위해 데이터 샤딩을 설정하고, 테스트 생성을 실행하고, 학습 루프를 실행하고, 파인 튜닝의 효과를 보여주기 위해 최종 텍스트 생성을 실행하는 전체 프로세스를 오케스트레이션합니다.
  • 스크립트는 argparse 라이브러리를 사용하여 maxlen, batch_size, datacount 매개변수의 명령줄 인수를 허용합니다.

이제 미세 조정 스크립트를 살펴보았으므로 GKE에서 실행되도록 컨테이너화합니다.

미세 조정 스크립트 컨테이너화

GKE 클러스터에서 미세 조정 스크립트를 실행하기 전에 컨테이너화해야 합니다. 이 튜토리얼에서는 JAX AI 이미지를 기본 이미지로 사용합니다.

  1. Gemma3LLMTrain.py 파일과 동일한 디렉터리에서 Dockerfile을 엽니다.

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    FROM us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:jax0.7.2-rev1
    RUN apt-get update && apt-get install -y wget && rm -rf /var/lib/apt/lists/*
    
    RUN pip install --upgrade pip
    RUN pip install gemma grain
    
    WORKDIR /app
    
    # Copy your training script into the container
    COPY Gemma3LLMTrain.py .
    

    이 Dockerfile은 필요한 종속 항목을 설치하고 Gemma3LLMTrain.py 파일을 컨테이너에 복사합니다.

  2. Docker 이미지를 빌드하고 이미지 저장소에 푸시합니다.

    export REPOSITORY=REPOSITORY_NAME
    export IMAGE_NAME="jax-gemma3-training"
    export IMAGE_TAG="latest"
    export DOCKERFILE_PATH="./Dockerfile"
    export IMAGE_URI="${REGION}-docker.pkg.dev/${PROJECT_ID}/${REPOSITORY}/${IMAGE_NAME}:${IMAGE_TAG}"
    
    docker build -t "${IMAGE_URI}" -f "${DOCKERFILE_PATH}" .
    gcloud auth configure-docker "${REGION}-docker.pkg.dev" -q
    docker push "${IMAGE_URI}"
    

    REPOSITORY_NAME을 Artifact Registry 저장소의 이름으로 바꿉니다.

  3. 서비스 계정에 역할 바인딩을 추가합니다.

    export PROJECT_NUMBER=$(gcloud projects describe $PROJECT_ID --format 'get(projectNumber)')
    gcloud artifacts repositories add-iam-policy-binding ${REPOSITORY} \
        --project=${PROJECT_ID} \
        --location=${REGION} \
        --member="serviceAccount:${PROJECT_NUMBER}-compute@developer.gserviceaccount.com" \
        --role="roles/artifactregistry.reader"
    

저장소에 이미지가 있으므로 이제 미세 조정 작업을 GKE 클러스터에 배포할 수 있습니다.

LLM 미세 조정 작업 배포

이 섹션에서는 LLM 파인 튜닝 작업을 GKE 클러스터에 배포하는 방법을 보여줍니다.

  1. training_singlehost.yaml 매니페스트를 엽니다.

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    apiVersion: batch/v1
    kind: Job
    metadata:
      name: jax-gemma3-train-singlehost
    spec:
      template:
        metadata:
          annotations:
            gke-gcsfuse/volumes: "true"
        spec:
          serviceAccountName: jaxserviceaccout
          containers:
          - name: training-container
            image: ${IMAGE_URI}
            imagePullPolicy: "Always"
            command: ["python", "Gemma3LLMTrain.py", "--maxlen", "256", "--batch_size", "64", "--datacount", "355120"]
            resources:
              limits:
                google.com/tpu: 1
            volumeMounts:
            - name: gcs-fuse-csi-ephemeral
              mountPath: /data
          nodeSelector:
            cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
            cloud.google.com/gke-tpu-topology: 1x1
          restartPolicy: Never
          volumes:
          - name: gcs-fuse-csi-ephemeral
            csi:
              driver: gcsfuse.csi.storage.gke.io
              volumeAttributes:
                bucketName: ${GCS_BUCKET_NAME}
                mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
      backoffLimit: 1
  2. 매니페스트를 적용합니다.

    envsubst < training_singlehost.yaml | kubectl apply -f -
    

GKE는 TPU Trillium (v6e) 노드에서 포드를 실행하는 작업을 만듭니다. 이 포드는 Python 미세 조정 스크립트를 실행합니다. 이 스크립트는 Cloud Storage FUSE를 사용하여 /data 경로에 마운트된 지정된 Cloud Storage 버킷에서 미세 조정 데이터에 액세스합니다. 그런 다음 스크립트가 Gemma 모델을 미세 조정합니다.

학습 작업 모니터링

이 섹션에서는 파인 튜닝 작업의 진행 상황과 성능을 모니터링합니다.

파인 튜닝 진행 상황 보기

  1. 포드를 나열합니다.

    # Find the Pods
    kubectl get pods
    
  2. 로그 출력을 따릅니다.

    kubectl logs -f pods/POD_NAME
    

    POD_NAME을 포드 이름으로 바꿉니다.

    출력은 다음과 비슷합니다.

    Global device count: 1
    Batch size: 128, Max length: 256, Data count: 96000
    I1028 00:12:55.925999 1387 google_auth_provider.cc:181] Running on GCE, using service account ...
    Generating response for: Once upon a time, there was a girl named Amy.
    Response:
    Amy lived in a small house. The house was in a big field. Amy liked to play in the big field. She
    Start training model
    Loss after batch 0: 10.25
    Loss after batch 10: 4.3125
    .
    .
    .
    Loss after batch 740: 1.41406
    Completed training model. Total time for training 294.6791355609894 seconds
    Generating response for: Once upon a time, there was a girl named Amy.
    Response:
    She loved to play with her toys. One day, Amy's mom told her that she had to go to the store to
    
  3. 출력을 분석합니다.

    • Global device count: 1 선은 사용된 TPU 코어를 나타냅니다.
    • 모델은 사전 학습된 체크포인트에서 로드되므로 이 미세 조정 실행 전에 합리적인 텍스트를 생성합니다.
    • 미세 조정 후 생성된 출력은 단편 소설의 시작 부분과 더 유사하며, 이는 모델이 새 데이터 세트에서 학습하고 있음을 나타냅니다.
    • 전체 데이터 세트를 미세 조정하면 더욱 세련된 출력이 생성됩니다.

측정항목 관찰

TPU 및 CPU 측정항목을 확인하여 미세 조정 작업의 성능을 확인합니다. 클러스터의 관측 가능성 측정항목을 보려면 클러스터 및 워크로드 관측 가능성 측정항목 보기의 단계를 따르세요.

대체 미세 조정 구성

이 섹션에서는 미세 조정 워크로드의 대체 구성을 간략하게 설명합니다.

모델 선택

이 튜토리얼에서는 단일 호스트 TPU Trillium (v6e) 노드 풀에 적합한 소규모 모델인 Gemma3_270M 모델을 사용했습니다. 파인 튜닝에 더 많은 메모리와 컴퓨팅이 필요한 대규모 모델의 경우 멀티 호스트 또는 멀티 슬라이스 노드 풀 구성을 사용할 수 있습니다.

사용 가능한 모델의 전체 목록은 Gemma 문서를 참고하세요.

노드 풀 구성

이 튜토리얼에서는 단일 호스트 노드 풀을 사용했습니다. 필요에 따라 멀티 호스트 TPU 슬라이스 노드 풀 또는 멀티슬라이스 노드 풀을 만들 수도 있습니다.

다음 탭에서는 멀티 호스트 및 멀티 슬라이스 노드 풀을 만드는 방법을 보여줍니다.

멀티 호스트

  1. Cloud Shell에서 다음 명령어를 실행합니다.

    gcloud container node-pools create jax-tpu-multihost1 \
        --cluster=${CLUSTER_NAME} \
        --machine-type=ct6e-standard-4t \
        --num-nodes=2 \
        --tpu-topology=2x4 \
        --location=${REGION} \
        --node-locations=${ZONE}
    

    GKE는 2x4 토폴로지와 2개의 노드로 TPU Trillium 노드 풀을 만듭니다.

  2. training_multihost_jobset.yaml 작업 정의를 엽니다.

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: jax-gemma3-train-multihost
    spec:
      replicatedJobs:
        - name: trainers
          replicas: 1
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 1
              template:
                metadata:
                  annotations:
                    gke-gcsfuse/volumes: "true"
                spec:
                  serviceAccountName: jaxserviceaccout
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                    cloud.google.com/gke-tpu-topology: 2x4
                    cloud.google.com/gke-nodepool: jax-tpu-multihost1
                  containers:
                  - name: training-container
                    image: ${IMAGE_URI} 
                    imagePullPolicy: "Always"
                    ports:
                      - containerPort: 8471
                    command: ["python", "Gemma3LLMTrain.py", "--maxlen", "256", "--batch_size", "64", "--datacount", "5120"]
                    resources:
                      limits:
                        google.com/tpu: 4
                    volumeMounts:
                    - name: gcs-fuse-csi-ephemeral
                      mountPath: /data
                  volumes:
                    - name: gcs-fuse-csi-ephemeral
                      csi:
                        driver: gcsfuse.csi.storage.gke.io
                        volumeAttributes:
                          bucketName: ${GCS_BUCKET_NAME}
                          mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
    
  3. 파인 튜닝 작업을 배포합니다.

    envsubst < training_multihost_jobset.yaml | kubectl apply -f -
    

멀티슬라이스

  1. Cloud Shell에서 다음 명령어를 실행합니다.

    gcloud container node-pools create jax-tpu-multihost1 \
      --cluster=${CLUSTER_NAME} \
      --machine-type=ct6e-standard-4t \
      --num-nodes=2 \
      --tpu-topology=2x4 \
      --location=${REGION} \
      --node-locations=${ZONE}
    
    gcloud container node-pools create jax-tpu-multihost2 \
      --cluster=${CLUSTER_NAME} \
      --machine-type=ct6e-standard-4t \
      --num-nodes=2 \
      --tpu-topology=2x4 \
      --location=${REGION} \
      --node-locations=${ZONE}
    

    GKE는 TPU Trillium 노드 풀 2개를 만듭니다. 각 노드 풀에는 2x4 토폴로지와 노드 2개가 있습니다.

  2. training_multislice_jobset.yaml 작업 정의를 엽니다.

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: jax-gemma3-train-multislice
    spec:
      replicatedJobs:
        - name: trainers
          replicas: 2
          template:
            spec:
              parallelism: 2
              completions: 2
              backoffLimit: 1
              template:
                metadata:
                  annotations:
                    gke-gcsfuse/volumes: "true"
                spec:
                  serviceAccountName: jaxserviceaccout
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                    cloud.google.com/gke-tpu-topology: 2x4
                  containers:
                  - name: training-container
                    image: ${IMAGE_URI}
                    imagePullPolicy: "Always"
                    ports:
                      - containerPort: 8471
                    command: ["python", "Gemma3LLMTrain.py", "--maxlen", "256", "--batch_size", "64", "--datacount", "5120"]
                    resources:
                      limits:
                        google.com/tpu: 4
                    volumeMounts:
                    - name: gcs-fuse-csi-ephemeral
                      mountPath: /data
                  volumes:
                    - name: gcs-fuse-csi-ephemeral
                      csi:
                        driver: gcsfuse.csi.storage.gke.io
                        volumeAttributes:
                          bucketName: ${GCS_BUCKET_NAME}
                          mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:100"
    
  3. 파인 튜닝 작업을 배포합니다.

    envsubst < training_multislice_jobset.yaml | kubectl apply -f -
    

성능 분석 및 최적화

머신러닝 미세 조정의 성능을 분석하고 최적화하려면 XProf를 사용하면 됩니다. XProf는 JAX, TensorFlow 또는 PyTorch/XLA로 빌드된 ML 워크로드를 프로파일링하고 검사하는 도구 모음입니다. 실행 추적, 메모리 사용량, 기타 데이터를 표시하여 XProf를 사용하면 효율성을 높이고 학습 속도를 높이도록 모델과 학습 설정을 미세 조정할 수 있습니다.

XProf를 사용하여 미세 조정 워크로드의 성능을 분석하려면 이 섹션에서 다음 단계를 완료하세요.

  • xprof 패키지를 설치합니다. XProf 서버를 시작하도록 학습 스크립트를 수정합니다.
  • XProf 로그의 볼륨 마운트를 포함하도록 Kubernetes 작업 매니페스트를 수정합니다.
  • 서비스 계정에 Cloud Storage 버킷에 XProf 로그를 쓸 수 있는 권한을 부여합니다.
  • 포드 내에서 XProf를 실행하고 XProf 대시보드에 액세스하도록 포트 전달을 설정합니다.

XProf 패키지 설치

  1. XProf 샘플이 포함된 디렉터리로 이동합니다.

      cd ai-ml/llm-training-jax-tpu-gemma3/xprof-enabled
    
  2. Docker 이미지를 빌드하고 이미지 저장소에 푸시합니다.

    export REPOSITORY=REPOSITORY_NAME
    export IMAGE_NAME="jax-gemma3-training-xp"
    export IMAGE_TAG="latest"
    export DOCKERFILE_PATH="./Dockerfile"
    export IMAGE_URI="${REGION}-docker.pkg.dev/${PROJECT_ID}/${REPOSITORY}/${IMAGE_NAME}:${IMAGE_TAG}"
    
    docker build -t "${IMAGE_URI}" -f "${DOCKERFILE_PATH}" .
    gcloud auth configure-docker "${REGION}-docker.pkg.dev" -q
    docker push "${IMAGE_URI}"
    

    REPOSITORY_NAME을 Artifact Registry 저장소의 이름으로 바꿉니다.

  3. Dockerfile 스크립트를 실행합니다.

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    FROM us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:jax0.7.2-rev1
    RUN apt-get update && apt-get install -y wget && rm -rf /var/lib/apt/lists/*
    
    RUN pip install --upgrade pip
    RUN pip install gemma grain equinox
    RUN pip install xprof
    
    WORKDIR /app
    
    # Copy your training script into the container
    COPY Gemma3LLMTrain.py .
    

    이 Dockerfile은 XProf 종속 항목을 설치합니다.

미세 조정 스크립트를 컨테이너에 복사합니다.

이 섹션에서는 XProf 로그에 필요한 볼륨 마운트가 포함된 Kubernetes 작업 매니페스트를 만들고 적용합니다.

  1. training_singlehost.yaml 작업 정의를 엽니다.

    # Copyright 2026 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    apiVersion: batch/v1
    kind: Job
    metadata:
      name: jax-gemma3-train-singlehost
    spec:
      template:
        metadata:
          annotations:
            gke-gcsfuse/volumes: "true"
        spec:
          serviceAccountName: jaxserviceaccout
          containers:
          - name: training-container
            image: ${IMAGE_URI}
            imagePullPolicy: "Always"
            command: ["python", "Gemma3LLMTrain.py", "--maxlen", "256", "--batch_size", "64", "--datacount", "851200"]
            resources:
              limits:
                google.com/tpu: 1
            volumeMounts:
            - name: gcs-fuse-csi-ephemeral
              mountPath: /data
            - name: gcs-fuse-csi-ephemeral2
              mountPath: /xprof
          nodeSelector:
            cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
            cloud.google.com/gke-tpu-topology: 1x1
          restartPolicy: Never
          volumes:
          - name: gcs-fuse-csi-ephemeral
            csi:
              driver: gcsfuse.csi.storage.gke.io
              volumeAttributes:
                bucketName: ${GCS_BUCKET_NAME}
                mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
          - name: gcs-fuse-csi-ephemeral2
            csi:
              driver: gcsfuse.csi.storage.gke.io
              volumeAttributes:
                bucketName: ${XPROF_GCS_BUCKET_NAME}
                mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
      backoffLimit: 1
  2. 매니페스트를 적용합니다.

    envsubst < training_singlehost.yaml | kubectl apply -f -
    

서비스 계정에 XProf 로그를 쓸 수 있는 권한 부여

  1. 서비스 계정이 쓰고 읽을 수 있도록 "roles/storage.objectUser" 역할을 추가합니다.

    export GSA_NAME="GSA_NAME" # Same as used in initial setup
    
    # Automatically get the current project ID
    export PROJECT_ID=$(gcloud config get-value project)
    
    # Cloud Storage Bucket details
    export XPROF_GCS_BUCKET_NAME="XPROF_GCS_BUCKET_NAME"
    
    # Derived Variables
    export GSA_EMAIL="${GSA_NAME}@${PROJECT_ID}.iam.gserviceaccount.com"
    
    gcloud storage buckets add-iam-policy-binding "gs://${XPROF_GCS_BUCKET_NAME}" \
        --member="serviceAccount:${GSA_EMAIL}" \
        --role="roles/storage.objectUser" \
        --project="${PROJECT_ID}"
    

    다음을 바꿉니다.

    • GSA_NAME: 역할을 부여할 Google 서비스 계정의 이름입니다.
    • XPROF_GCS_BUCKET_NAME: 역할을 부여할 버킷의 이름입니다.
  2. 포드 내에서 XProf를 실행합니다.

    kubectl exec POD_NAME -c training-container -it -- bash # exec into the container
    xprof --port 9001 --logdir /xprof # start xprof
    

    POD_NAME을 포드 이름으로 바꿉니다.

XProf 대시보드에 액세스

  1. 포드에서 XProf 서버로 포트 전달을 설정합니다.

    kubectl port-forward POD_NAME 9001:9001
    
  2. 브라우저의 주소 표시줄에 다음을 입력합니다.

    http://localhost:9001/
    

    XProf 추적 뷰어가 열립니다.

  3. 텐서보드 창에서 프로필 캡처를 클릭합니다.

  4. 프로필 서비스 URL 또는 TPU 이름 필드에 localhost:9002를 입력합니다.

  5. 자세한 내용을 캡처하려면 호스트 추적 (TraceMe) 수준에서 verbose를 선택하고 Python 추적 로깅을 사용 설정합니다.

  6. 대시보드를 보려면 캡처를 클릭합니다.

    TensorBoard는 프로필을 캡처하고 학습 스크립트의 성능을 분석할 수 있도록 지원합니다. 그래프에는 TPU와 CPU 성능 프로필의 실행 타임라인이 표시됩니다.

성능 매트릭스 그래프를 보여주는 XProf 트레이스 뷰어의 예

학습 워크로드 성능을 분석하기 위한 추가 프로파일링 옵션은 계산 프로파일링에 관한 JAX 문서를 참고하세요.

프로덕션 환경에서 미세 조정

이 튜토리얼에서는 분산 환경에서 JAX 기반 학습을 테스트하는 방법을 살펴보았습니다. 프로덕션에서 최적화된 LLM 미세 조정에는 Maxtext 라이브러리를 사용하세요. 확산 모델에 관심이 있다면 Maxdiffusion 구현을 사용하세요.

프로덕션에서 장기 실행 학습 또는 미세 조정 워크로드의 경우 장애 발생 시 진행률 손실을 최소화하도록 워크로드 체크포인트를 설정하세요. 멀티 계층 체크포인트 설정에 대해 자세히 알아보려면 멀티 계층 체크포인트를 사용하여 GKE에서 대규모 머신러닝 모델 학습을 참고하세요.

삭제

이 튜토리얼에서 사용된 리소스 비용이 Google Cloud 계정에 청구되지 않도록 하려면 리소스가 포함된 프로젝트를 삭제하거나 프로젝트를 유지하고 개별 리소스를 삭제하세요.

개별 리소스 삭제

이 튜토리얼에서 사용된 리소스 비용이 Google Cloud 계정에 청구되지 않도록 하려면 리소스가 포함된 프로젝트를 삭제하거나 다음 명령어를 실행하여 프로젝트는 유지하되 개별 리소스를 삭제하세요.

  1. 이 튜토리얼에서 만든 리소스를 삭제합니다.

    gcloud container clusters delete ${CLUSTER_NAME} --location=${REGION}
    gcloud storage rm --recursive gs://${GCS_BUCKET_NAME}
    gcloud artifacts docker images delete ${IMAGE_URI} --delete-tags
    
  2. XProf에서 생성된 데이터가 필요하지 않으면 XProf에서 사용한 Cloud Storage 버킷을 삭제합니다.

    gcloud storage rm --recursive gs://${XPROF_GCS_BUCKET_NAME}
    

다음 단계