本教學課程說明如何使用 JAX,在 Google Kubernetes Engine (GKE) 上透過張量處理單元 (TPU) 微調大型語言模型 (LLM)。微調功能可讓您調整 Gemma 3 等基礎模型,使其適用於特定領域或工作。這個程序會使用您自己的專業資料集更新模型參數,藉此提升模型的精確度和準確率。
如果您需要精細控管、自訂、擴充、復原、移植及提高成本效益,並在微調 AI/機器學習工作負載時使用代管型 Kubernetes,這份指南就是不錯的起點。
背景
在 GKE 上使用 TPU 和 Jax 微調 LLM,即可建構完善且可用於正式環境的微調解決方案,同時享有代管型 Kubernetes 的所有優點。
Gemma
Gemma 是一組開放式輕量級生成式 AI/ML 多模態模型,以開放授權形式發布。您可以在應用程式、硬體、行動裝置或代管服務中執行這些 AI 模型。Gemma 3 導入多模態功能,支援視覺語言輸入和文字輸出。可處理最多 128,000 個權杖的脈絡窗口,並支援超過 140 種語言。Gemma 3 的數學、推論和聊天功能也獲得提升,包括結構化輸出和函式呼叫。
您可以使用 Gemma 模型生成文字,也可以調整這些模型來執行特定工作。
詳情請參閱 Gemma 說明文件。
TPU
TPU 是 Google 客製化開發的特殊應用積體電路 (ASIC),可加速機器學習和 AI 模型,這些模型是使用 TensorFlow、PyTorch 和 JAX 等架構建構而成。
在 GKE 中使用 TPU 之前,建議您先完成下列學習路徑:
- 如要瞭解目前可用的 TPU 版本,請參閱 Cloud TPU 系統架構。
- 瞭解 GKE 中的 TPU。
JAX
JAX 是一種高效能機器學習框架,專為搭配 TPU 和 GPU 使用而設計。JAX 提供 API,用於建構及訓練機器學習模型。
詳情請參閱 JAX 存放區。
目標
本教學課程包含下列步驟:
- 根據模型特性,建立具有建議 TPU 拓撲的 GKE Autopilot 或 Standard 叢集。在本教學課程中,您將對單一主機節點集區執行微調作業。
- 將資料新增至 Cloud Storage 值區,並透過 Cloud Storage FUSE 將資料掛接到容器。
- 在 GKE 上部署 LLM 微調作業。
- 監控微調作業並查看記錄。
事前準備
- 登入 Google Cloud 帳戶。如果您是 Google Cloud新手,歡迎 建立帳戶,親自評估產品在實際工作環境中的成效。新客戶還能獲得價值 $300 美元的免費抵免額,可用於執行、測試及部署工作負載。
-
In the Google Cloud console, on the project selector page, select or create a Google Cloud project.
Roles required to select or create a project
- Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
-
Create a project: To create a project, you need the Project Creator role
(
roles/resourcemanager.projectCreator), which contains theresourcemanager.projects.createpermission. Learn how to grant roles.
-
Verify that billing is enabled for your Google Cloud project.
Enable the required API.
Roles required to enable APIs
To enable APIs, you need the Service Usage Admin IAM role (
roles/serviceusage.serviceUsageAdmin), which contains theserviceusage.services.enablepermission. Learn how to grant roles.-
In the Google Cloud console, on the project selector page, select or create a Google Cloud project.
Roles required to select or create a project
- Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
-
Create a project: To create a project, you need the Project Creator role
(
roles/resourcemanager.projectCreator), which contains theresourcemanager.projects.createpermission. Learn how to grant roles.
-
Verify that billing is enabled for your Google Cloud project.
Enable the required API.
Roles required to enable APIs
To enable APIs, you need the Service Usage Admin IAM role (
roles/serviceusage.serviceUsageAdmin), which contains theserviceusage.services.enablepermission. Learn how to grant roles.-
確認您在專案中具備下列角色: roles/container.admin、roles/iam.serviceAccountAdmin、roles/storage.admin
檢查角色
-
前往 Google Cloud 控制台的「IAM」頁面。
前往「IAM」頁面 - 選取專案。
-
在「主體」欄中,找出所有識別您或您所屬群組的資料列。如要瞭解自己所屬的群組,請與管理員聯絡。
- 針對指定或包含您的所有列,請檢查「角色」欄,確認角色清單是否包含必要角色。
授予角色
-
前往 Google Cloud 控制台的「IAM」頁面。
前往「IAM」頁面 - 選取專案。
- 按一下「Grant access」(授予存取權)。
-
在「New principals」(新增主體) 欄位中,輸入您的使用者 ID。 這通常是指 Google 帳戶的電子郵件地址。
- 按一下「選取角色」,然後搜尋角色。
- 如要授予其他角色,請按一下「Add another role」(新增其他角色),然後新增其他角色。
- 按一下「Save」(儲存)。
-
- 請確認您有足夠的配額,可使用 16 個 TPU Trillium (v6e) 晶片。在本教學課程中,您將使用需要 16 個晶片和隨選執行個體的節點集區設定。
- 確認您有 Docker 存放區。如果沒有,請在 Artifact Registry 中建立標準存放區。
準備環境
在本教學課程中,您將使用 Cloud Shell 管理 Google Cloud上代管的資源。Cloud Shell 已預先安裝本教學課程所需的軟體,包括 kubectl 和 Google Cloud CLI。
如要使用 Cloud Shell 設定環境,請按照下列步驟操作:
在 Google Cloud 控制台中啟動 Cloud Shell 工作階段,然後按一下「啟用 Cloud Shell」
。這項操作會在 Google Cloud 控制台的底部窗格啟動工作階段。
設定預設環境變數:
gcloud config set project PROJECT_ID gcloud config set billing/quota_project PROJECT_ID export PROJECT_ID=$(gcloud config get project) export CLUSTER_NAME=CLUSTER_NAME export REGION=CONTROL_PLANE_LOCATION export ZONE=ZONE export GCS_BUCKET_NAME=BUCKET_NAME替換下列值:
PROJECT_ID:您的 Google Cloud 專案 ID。CLUSTER_NAME:GKE 叢集的名稱。CONTROL_PLANE_LOCATION:GKE 叢集和 TPU 節點所在的 Compute Engine 區域。該區域必須包含提供 TPU Trillium (v6e) 機型的可用區。ZONE:所選CONTROL_PLANE_LOCATION區域內的可用區,提供 TPU Trillium (v6e) 機型。如要列出提供 TPU Trillium (v6e) TPU 的區域,請執行下列指令:gcloud compute accelerator-types list --filter="name~ct6e" --format="value(zone)"BUCKET_NAME:包含訓練資料的 Cloud Storage bucket 名稱。
複製範例存放區:
git clone https://github.com/GoogleCloudPlatform/kubernetes-engine-samples.git cd kubernetes-engine-samples前往工作目錄:
cd ai-ml/llm-training-jax-tpu-gemma3
建立及設定 Google Cloud 資源
在本節中,您將建立及設定 Google Cloud 資源。
建立 GKE 叢集
您可以在 GKE Autopilot 或 Standard 叢集上,使用 TPU 微調 LLM。建議您使用 Autopilot 叢集,享受全代管的 Kubernetes 體驗。如要為工作負載選擇最合適的 GKE 作業模式,請參閱「選擇 GKE 作業模式」。
Autopilot
建立使用 Workload Identity Federation for GKE,並啟用 Cloud Storage FUSE 的 GKE Autopilot 叢集。
gcloud container clusters create-auto ${CLUSTER_NAME} \
--location=${REGION}
建立叢集可能需要幾分鐘的時間。
標準
建立使用 Workload Identity Federation for GKE 的地區 GKE Standard 叢集,並啟用 Cloud Storage FUSE。
gcloud container clusters create ${CLUSTER_NAME} \ --enable-ip-alias \ --addons GcsFuseCsiDriver \ --machine-type=n2-standard-4 \ --num-nodes=2 \ --workload-pool=${PROJECT_ID}.svc.id.goog \ --location=${REGION}建立叢集可能需要幾分鐘的時間。
建立單一主機節點集區:
gcloud container node-pools create jax-tpu-nodepool \ --cluster=${CLUSTER_NAME} \ --machine-type=ct6e-standard-1t \ --num-nodes=1 \ --location=${REGION} \ --node-locations=${ZONE} \ --workload-metadata=GKE_METADATA
GKE 會建立具有 1x1 拓撲和一個節點的 TPU Trillium 節點集區。--workload-metadata=GKE_METADATA 旗標會將節點集區設定為使用 GKE 中繼資料伺服器。
安裝 JobSet
設定
kubectl,與叢集通訊:gcloud container clusters get-credentials ${CLUSTER_NAME} --location=${REGION}安裝最新發布的 JobSet 版本:
kubectl apply --server-side -f https://github.com/kubernetes-sigs/jobset/releases/download/JOBSET_VERSION/manifests.yaml將
JOBSET_VERSION替換為 JobSet 的最新發布版本。例如:v0.11.0。驗證 JobSet 安裝狀態:
kubectl get pods -n jobset-system輸出結果會與下列內容相似:
NAME READY STATUS RESTARTS AGE jobset-controller-manager-6c56668494-l4dhc 1/1 Running 0 4m45s如果 JobSet 正在等待資源,您可能需要新增更多節點。
設定 Cloud Storage FUSE
如要微調 LLM,請提供訓練資料。在本教學課程中,您將使用 Hugging Face 的 TinyStories 資料集。這個資料集包含由 GPT-3.5 和 GPT-4 合成生成,且使用有限詞彙的短篇故事。
本節說明如何設定 Cloud Storage FUSE,從 Cloud Storage 值區讀取資料。
下載資料集:
wget https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories-train.txt?download=true -O TinyStories-train.txt將資料上傳至新的 Cloud Storage 值區:
gcloud storage buckets create gs://${GCS_BUCKET_NAME} \ --location=${REGION} \ --enable-hierarchical-namespace \ --uniform-bucket-level-access gcloud storage cp TinyStories-train.txt gs://${GCS_BUCKET_NAME}如要允許工作負載透過 Cloud Storage FUSE 讀取資料,請建立 Kubernetes 服務帳戶 (KSA) 並新增必要權限。執行
permissionsetup.sh指令碼:執行這項指令碼後,系統會在Google Cloud 專案和 GKE 叢集中設定下列資源:
- 系統會在專案中建立名為
gcs-fuse-sa的新 IAM 服務帳戶。 - 建立的 Google Cloud 服務帳戶 (GSA) (
gcs-fuse-sa) 會在${GCS_BUCKET_NAME}指定的 Cloud Storage 儲存空間上獲得roles/storage.objectViewer角色。這項權限可讓 GSA 從值區讀取物件。 - 系統會在 GKE 叢集的
default命名空間中,建立名為jaxserviceaccount的新 KSA。 - 更新 GSA 的 IAM 政策,將
roles/iam.workloadIdentityUser角色授予 KSA。這項權限允許 KSA 模擬 GSA。 系統會註解 KSA,將其連結至 GSA。這項註解會告知 GKE,KSA 應使用 Workload Identity 模擬哪個 GSA。
現在,凡是在 GKE 叢集的
default命名空間中執行的 Pod,只要使用jaxserviceaccount服務帳戶,就能以gcs-fuse-saGSA 身分進行驗證。這些 Pod 將具備gs://${GCS_BUCKET_NAME}bucket 中儲存物件的讀取權,這對於微調工作使用 Cloud Storage FUSE 存取資料集至關重要。
- 系統會在專案中建立名為
建立微調指令碼
在本節中,您將瞭解訓練指令碼,該指令碼會對 Gemma 3 模型執行微調作業。這個指令碼使用 Gemma3Tokenizer。
請查看下列Gemma3LLMTrain.py微調指令碼:
這段指令碼適用下列情況:
Gemma3Tokenizer會將文字資料轉換為模型可處理的權杖。load_and_preprocess_data函式會從檔案讀取訓練資料,將資料分割為個別故事,並使用權杖化工具將文字轉換為權杖的填補序列。generate_text函式會採用模型、模型參數和提示來生成文字。train_step函式會定義單一訓練疊代,包括正向傳遞、損失計算 (使用交叉熵)、梯度計算和參數更新。train_model函式會針對指定數量的訓練週期,逐一處理資料集,並針對每個批次呼叫train_step函式。run_training函式會協調整個程序,包括載入資料、初始化 Gemma 3 模型 (Gemma3_270M) 和最佳化工具、載入預先訓練的參數、設定資料分片以進行平行處理、執行測試生成作業、執行訓練迴圈,以及執行最終的文字生成作業,以展示微調效果。- 這個指令碼會使用
argparse程式庫,接受maxlen、batch_size和datacount參數的指令列引數。
您已瞭解微調指令碼,現在可以將其容器化,以便在 GKE 上執行。
將微調指令碼容器化
在 GKE 叢集中執行微調指令碼前,您需要先將其容器化。本教學課程使用 JAX AI 圖像做為基礎圖像。
在
Gemma3LLMTrain.py檔案所在的目錄中開啟Dockerfile:這個 Dockerfile 會安裝必要的依附元件,並將
Gemma3LLMTrain.py檔案複製到容器中。建構 Docker 映像檔並推送至映像檔存放區:
export REPOSITORY=REPOSITORY_NAME export IMAGE_NAME="jax-gemma3-training" export IMAGE_TAG="latest" export DOCKERFILE_PATH="./Dockerfile" export IMAGE_URI="${REGION}-docker.pkg.dev/${PROJECT_ID}/${REPOSITORY}/${IMAGE_NAME}:${IMAGE_TAG}" docker build -t "${IMAGE_URI}" -f "${DOCKERFILE_PATH}" . gcloud auth configure-docker "${REGION}-docker.pkg.dev" -q docker push "${IMAGE_URI}"將
REPOSITORY_NAME替換為 Artifact Registry 存放區的名稱。將角色繫結新增至服務帳戶:
export PROJECT_NUMBER=$(gcloud projects describe $PROJECT_ID --format 'get(projectNumber)') gcloud artifacts repositories add-iam-policy-binding ${REPOSITORY} \ --project=${PROJECT_ID} \ --location=${REGION} \ --member="serviceAccount:${PROJECT_NUMBER}-compute@developer.gserviceaccount.com" \ --role="roles/artifactregistry.reader"
映像檔已儲存至存放區,現在可以將微調作業部署至 GKE 叢集。
部署 LLM 微調工作
本節說明如何將 LLM 微調工作部署至 GKE 叢集。
開啟
training_singlehost.yaml資訊清單:套用資訊清單:
envsubst < training_singlehost.yaml | kubectl apply -f -
GKE 會建立 Job,在 TPU Trillium (v6e) 節點上啟動 Pod。這個 Pod 會執行 Python 微調指令碼,並使用 Cloud Storage FUSE,從掛接在 /data 路徑的指定 Cloud Storage 值區存取微調資料。然後指令碼會微調 Gemma 模型。
監控訓練工作
在本節中,您將監控微調工作的進度和效能。
查看微調進度
列出 Pod:
# Find the Pods kubectl get pods追蹤記錄輸出:
kubectl logs -f pods/POD_NAME將
POD_NAME替換為 Pod 名稱。輸出結果會與下列內容相似:
Global device count: 1 Batch size: 128, Max length: 256, Data count: 96000 I1028 00:12:55.925999 1387 google_auth_provider.cc:181] Running on GCE, using service account ... Generating response for: Once upon a time, there was a girl named Amy. Response: Amy lived in a small house. The house was in a big field. Amy liked to play in the big field. She Start training model Loss after batch 0: 10.25 Loss after batch 10: 4.3125 . . . Loss after batch 740: 1.41406 Completed training model. Total time for training 294.6791355609894 seconds Generating response for: Once upon a time, there was a girl named Amy. Response: She loved to play with her toys. One day, Amy's mom told her that she had to go to the store to分析輸出內容:
Global device count: 1行表示使用的 TPU 核心。- 模型會從預先訓練的檢查點載入,因此在微調執行前會產生合理的文字。
- 微調後產生的輸出內容與短篇故事的開頭更為相似,表示模型正在從新資料集學習。
- 對完整資料集進行微調,應可產生更精確的輸出內容。
觀察指標
查看 TPU 和 CPU 指標,瞭解微調作業的效能。如要查看叢集的觀測指標,請按照「查看叢集和工作負載觀測指標」一文中的步驟操作。
其他微調設定
本節說明微調工作負載的替代設定。
多種模型供您選擇
本教學課程使用 Gemma3_270M 模型,這個小型模型適合單一主機 TPU Trillium (v6e) 節點集區。如果大型模型需要更多記憶體和運算資源才能微調,可以使用多主機或多切片節點集區設定。
如需可用模型的完整清單,請參閱 Gemma 說明文件。
節點集區設定
本教學課程使用單一主機節點集區。您也可以視需求建立多主機 TPU 配量節點集區或多配量節點集區。
以下分頁說明如何為多主機和多切片節點集區建立節點集區:
多主機
在 Cloud Shell 中執行下列指令:
gcloud container node-pools create jax-tpu-multihost1 \ --cluster=${CLUSTER_NAME} \ --machine-type=ct6e-standard-4t \ --num-nodes=2 \ --tpu-topology=2x4 \ --location=${REGION} \ --node-locations=${ZONE}GKE 會建立具有
2x4拓撲和兩個節點的 TPU Trillium 節點集區。開啟
training_multihost_jobset.yaml工作定義:部署微調作業:
envsubst < training_multihost_jobset.yaml | kubectl apply -f -
多配量
在 Cloud Shell 中執行下列指令:
gcloud container node-pools create jax-tpu-multihost1 \ --cluster=${CLUSTER_NAME} \ --machine-type=ct6e-standard-4t \ --num-nodes=2 \ --tpu-topology=2x4 \ --location=${REGION} \ --node-locations=${ZONE} gcloud container node-pools create jax-tpu-multihost2 \ --cluster=${CLUSTER_NAME} \ --machine-type=ct6e-standard-4t \ --num-nodes=2 \ --tpu-topology=2x4 \ --location=${REGION} \ --node-locations=${ZONE}GKE 會建立兩個 TPU Trillium 節點集區。每個節點集區都有
2x4拓撲和兩個節點。開啟
training_multislice_jobset.yaml工作定義:部署微調作業:
envsubst < training_multislice_jobset.yaml | kubectl apply -f -
效能分析與最佳化
如要分析及最佳化機器學習微調作業的效能,可以使用 XProf。XProf 是一套工具,可剖析及檢查以 JAX、TensorFlow 或 PyTorch/XLA 建構的機器學習工作負載。XProf 會顯示執行追蹤記錄、記憶體用量和其他資料,方便您微調模型和訓練設定,進而提高效率及加快訓練速度。
如要使用 XProf 分析微調工作負載的效能,請完成本節中的下列步驟:
- 安裝
xprof套件。修改訓練指令碼,啟動 XProf 伺服器。 - 修改 Kubernetes 工作資訊清單,加入 XProf 記錄的磁碟區掛接。
- 授予服務帳戶權限,將 XProf 記錄寫入 Cloud Storage bucket。
- 在 Pod 中執行 XProf,並設定通訊埠轉送,存取 XProf 資訊主頁。
安裝 XProf 套件
前往包含 XProf 範例的目錄:
cd ai-ml/llm-training-jax-tpu-gemma3/xprof-enabled建構 Docker 映像檔並推送至映像檔存放區:
export REPOSITORY=REPOSITORY_NAME export IMAGE_NAME="jax-gemma3-training-xp" export IMAGE_TAG="latest" export DOCKERFILE_PATH="./Dockerfile" export IMAGE_URI="${REGION}-docker.pkg.dev/${PROJECT_ID}/${REPOSITORY}/${IMAGE_NAME}:${IMAGE_TAG}" docker build -t "${IMAGE_URI}" -f "${DOCKERFILE_PATH}" . gcloud auth configure-docker "${REGION}-docker.pkg.dev" -q docker push "${IMAGE_URI}"將
REPOSITORY_NAME替換為 Artifact Registry 存放區的名稱。執行
Dockerfile指令碼:這個 Dockerfile 會安裝 XProf 依附元件。
將微調指令碼複製到容器中
在本節中,請建立並套用 Kubernetes Job 資訊清單,其中包含 XProf 記錄檔的必要磁碟區掛接。
開啟
training_singlehost.yaml工作定義:套用資訊清單:
envsubst < training_singlehost.yaml | kubectl apply -f -
授予服務帳戶寫入 XProf 記錄的權限
如要讓服務帳戶能夠寫入及讀取資料,請新增
"roles/storage.objectUser"角色:export GSA_NAME="GSA_NAME" # Same as used in initial setup # Automatically get the current project ID export PROJECT_ID=$(gcloud config get-value project) # Cloud Storage Bucket details export XPROF_GCS_BUCKET_NAME="XPROF_GCS_BUCKET_NAME" # Derived Variables export GSA_EMAIL="${GSA_NAME}@${PROJECT_ID}.iam.gserviceaccount.com" gcloud storage buckets add-iam-policy-binding "gs://${XPROF_GCS_BUCKET_NAME}" \ --member="serviceAccount:${GSA_EMAIL}" \ --role="roles/storage.objectUser" \ --project="${PROJECT_ID}"更改下列內容:
GSA_NAME:要授予角色的 Google 服務帳戶名稱。XPROF_GCS_BUCKET_NAME:要授予角色的 bucket 名稱。
在 Pod 中執行 XProf:
kubectl exec POD_NAME -c training-container -it -- bash # exec into the container xprof --port 9001 --logdir /xprof # start xprof將
POD_NAME替換為 Pod 名稱。
存取 XProf 資訊主頁
將通訊埠轉送設定為 Pod 中的 XProf 伺服器:
kubectl port-forward POD_NAME 9001:9001在瀏覽器的網址列中輸入下列內容:
http://localhost:9001/XProf Trace Viewer 隨即開啟。
在 TensorBoard 視窗中,按一下「擷取設定檔」。
在「Profile Service URL(s) or TPU name」(剖析服務網址或 TPU 名稱) 欄位中輸入
localhost:9002。如要擷取更多詳細資料,請在「主機追蹤 (TraceMe) 層級」中選取「詳細」,並啟用 Python 追蹤記錄功能。
如要查看資訊主頁,請按一下「擷取」。
TensorBoard 會擷取設定檔,方便您分析訓練指令碼的效能。圖表會顯示 TPU 和 CPU 效能設定檔的執行時間軸:
如要瞭解更多分析訓練工作負載效能的剖析選項,請參閱 JAX 說明文件中的「剖析運算」一節。
在實際工作環境中微調
本教學課程說明如何在分散式環境中測試以 JAX 為基礎的訓練作業。如要在正式環境中微調 LLM,請使用 Maxtext 程式庫。如果您對擴散模型感興趣,請使用 Maxdiffusion 實作項目。
如要在正式環境中執行長時間的訓練或微調工作負載,請設定工作負載查核點,盡量減少失敗時的進度損失。如要進一步瞭解如何設定多層級檢查點,請參閱使用多層級檢查點在 GKE 上訓練大規模機器學習模型。
清除所用資源
為避免因為本教學課程所用資源,導致系統向 Google Cloud 帳戶收取費用,請刪除含有相關資源的專案,或者保留專案但刪除個別資源。
刪除個別資源
如要避免系統向您的 Google Cloud 帳戶收取本教學課程所用資源的費用,請刪除含有相關資源的專案,或者保留專案但執行下列指令刪除個別資源:
刪除您在本教學課程中建立的資源:
gcloud container clusters delete ${CLUSTER_NAME} --location=${REGION} gcloud storage rm --recursive gs://${GCS_BUCKET_NAME} gcloud artifacts docker images delete ${IMAGE_URI} --delete-tags如果不需要 XProf 產生的資料,請移除 XProf 使用的 Cloud Storage bucket:
gcloud storage rm --recursive gs://${XPROF_GCS_BUCKET_NAME}
後續步驟
- 進一步瞭解 GKE 中的 TPU。
- 探索 JAX 存放區。