使用 TPU v6e 訓練模型

本文將逐步說明如何在 Cloud TPU v6e (又稱 Trillium) 上訓練模型,涵蓋環境設定、效能最佳化,以及使用 JAX 和 PyTorch/XLA 的實用訓練範例。

TPU v6e (又稱 Trillium) 是 Google 的第 6 代 TPU。在所有技術介面 (例如 API 和記錄) 和本文件中,Trillium 都會稱為 v6e。每個 Pod 都有 256 個晶片,因此 TPU v6e 的架構與 v5e 有許多相似之處。TPU v6e 經過最佳化,適用於 Transformer、文字轉圖像和卷積類神經網路 (CNN) 的訓練、微調和服務。如要進一步瞭解 TPU v6e 系統架構和設定,請參閱「TPU v6e」。

如要瞭解如何在 Cloud TPU v6e 上執行推論作業,請參閱下列教學課程:

事前準備

開始前,請先完成下列事項:

  • 建立 Google Cloud 帳戶和專案並啟用計費功能
  • 安裝 Google Cloud CLI Alpha 版元件
  • 啟用 Cloud TPU API
  • 建立 Cloud TPU 服務代理
  • 建立 Cloud TPU 服務帳戶並授予權限

詳情請參閱「設定 Cloud TPU 環境」。

確認配額和權限

確認專案具有下列配額:

如果您使用 Google Kubernetes Engine (GKE) 和 XPK (加速處理套件),則需要在 Google Cloud 控制台中取得額外權限。詳情請參閱「在 Google Cloud 控制台中需要的權限 」。

佈建選項

您可以使用下列方法佈建及管理 TPU v6e:

  • GKE:您可以透過 GKE 佈建及管理 TPU,做為容器化機器學習工作負載的加速器集區。詳情請參閱「GKE 中的 TPU 簡介」。
  • GKE 和 XPK:XPK 是一項指令列工具,可簡化在 GKE 上建立叢集和執行工作負載的作業。這項服務專為機器學習從業人員設計,可讓他們佈建 TPU 並執行訓練作業,不必具備深厚的 Kubernetes 專業知識。詳情請參閱 XPK GitHub 存放區
  • Cloud TPU 佇列資源:您可以要求佇列資源,系統會在資源可用時佈建 TPU 容量。非常適合批次工作和容錯工作負載,這些工作可以排隊等待。你可以為要求指定時間範圍。詳情請參閱「管理已加入佇列的資源」。

透過 GKE 和 XPK 佈建 v6e TPU

如果您使用 GKE 和 v6e TPU,可以透過 Kubernetes 指令或 XPK 佈建 TPU,並訓練或提供模型。如要進一步瞭解如何搭配使用 GKE 和 TPU,請參閱「GKE 中的 TPU 簡介」。

Cloud TPU v6e 支援網路介面卡 (NIC) 設定,可讓您跨多個網路擴充輸送量。以下各節提供指令,說明如何使用 XPK 建立支援單一 NIC 或多個 NIC 的 GKE 叢集。對於大多數單一分割區工作負載,單一 NIC 可提供足夠的效能,且設定較少。如果是 Multislice 工作負載和需要高資料擷取速度的工作負載,請使用多個 NIC

使用 XPK 建立支援單一 NIC 的叢集

對於大多數單一 Slice 工作負載,單一 NIC 提供的效能已足夠,且設定較少。如果是 Multislice 工作負載和需要高資料擷取速度的工作負載,請使用多個 NIC

下列各節說明如何使用 XPK 建立支援單一 NIC 的 GKE 叢集。

安裝 XPK 並設定環境變數

  1. 安裝 XPK。請按照 XPK GitHub 存放區中的操作說明進行。

  2. 為叢集設定環境變數:

    export CLUSTER_NAME=XPK_CLUSTER_NAME
    export ZONE=us-east1-d
    export PROJECT_ID=PROJECT_ID
    export ACCELERATOR_TYPE=ACCELERATOR_TYPE
    export NUM_SLICES=1

    請設定下列環境變數:

    • CLUSTER_NAME:叢集名稱。
    • ZONE:要建立 TPU 叢集的可用區。如要進一步瞭解支援的區域,請參閱「地區和區域」。
    • PROJECT_ID: Google Cloud 專案 ID。
    • ACCELERATOR_TYPE:TPU 類型 (也稱為加速器類型) 會指定要建立的 Cloud TPU 版本和大小。例如:v6e-256。如要進一步瞭解各 TPU 版本支援的加速器類型,請參閱「TPU 版本」。
    • NUM_SLICES:叢集的 TPU 節點數量。每個切片都有 ACCELERATOR_TYPE 中指定的晶片數量。如果是單一切片叢集,請將 NUM_SLICES 設為 1。如果是 Multislice 叢集,請根據工作負載的可擴充性需求指定切片數量。叢集中的晶片總數是 ACCELERATOR_TYPE 中的晶片數乘以 NUM_SLICES

建立叢集

選擇下列其中一個選項來建立叢集。建議使用 MTU 為 8,896 的自訂網路,以獲得最佳效能。詳情請參閱「設定 MTU」。

自訂網路

如要建立 MTU 為 8,896 的自訂網路,並用於叢集,請按照下列步驟操作:

  1. 為網路和防火牆名稱設定環境變數:

    export NETWORK_NAME=NETWORK_NAME
    export NETWORK_FW_NAME=FIREWALL_NAME

    更改下列內容:

    • NETWORK_NAME:網路名稱。
    • FIREWALL_NAME:網路防火牆規則的名稱。
  2. 建立 MTU 為 8,896 的自訂網路:

    gcloud compute networks create ${NETWORK_NAME} \
        --mtu=8896 \
        --project=${PROJECT_ID} \
        --subnet-mode=auto \
        --bgp-routing-mode=regional
  3. 建立防火牆規則,允許網路上的 TCP、ICMP 和 UDP 流量:

    gcloud compute firewall-rules create ${NETWORK_FW_NAME} \
        --network=${NETWORK_NAME} \
        --allow tcp,icmp,udp \
        --project=${PROJECT_ID}
  4. 為 XPK 叢集引數設定環境變數,以使用您建立的網路:

    export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"
  5. 建立 XPK 叢集。下列指令會佈建隨需容量:

    xpk cluster create --cluster=${CLUSTER_NAME} \
        --cluster-cpu-machine-type=e2-standard-8 \
        --num-slices=${NUM_SLICES} \
        --tpu-type=${ACCELERATOR_TYPE} \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        --on-demand \
        --custom-cluster-arguments="${CLUSTER_ARGUMENTS}"

    如要使用預留容量,請將 --on-demand 替換為 --reservation=RESERVATION_NAME。如要使用 TPU Spot VM,請將 --on-demand 替換為 --spot

預設網路

如果您不需要高 MTU 網路,可以建立使用預設虛擬私有雲網路的叢集。下列指令會佈建隨需容量:

xpk cluster create --cluster=${CLUSTER_NAME} \
    --cluster-cpu-machine-type=e2-standard-8 \
    --num-slices=${NUM_SLICES} \
    --tpu-type=${ACCELERATOR_TYPE} \
    --zone=${ZONE} \
    --project=${PROJECT_ID} \
    --on-demand

如要使用預留容量,請將 --on-demand 替換為 --reservation=RESERVATION_NAME。如要使用 TPU Spot VM,請將 --on-demand 替換為 --spot

使用 XPK 建立支援多重 NIC 的叢集

如果是 Multislice 工作負載或其他需要高網路頻寬的工作負載 (例如資料擷取),您可以使用多個 NIC 來提升效能。使用多個 NIC 時,每個 TPU VM 都會分配到額外的網路介面,且每個介面都會連線至專屬的 VPC 網路,進而提升整體網路輸送量。對於大多數單一工作負載,單一 NIC 提供的效能已足夠,且設定較少。

下列各節說明如何使用 XPK 建立支援多重 NIC 的 GKE 叢集。

安裝 XPK 並設定環境變數

  1. 安裝 XPK。請按照 XPK GitHub 存放區中的操作說明進行。

  2. 為叢集和主要網路設定環境變數:

    export CLUSTER_NAME=XPK_CLUSTER_NAME
    export REGION=REGION
    export ZONE=us-east1-d
    export PROJECT_ID=PROJECT_ID
    export ACCELERATOR_TYPE=ACCELERATOR_TYPE
    export NUM_SLICES=2
    
    export NETWORK_NAME_1=${CLUSTER_NAME}-mtu9k-1-${ZONE}
    export SUBNET_NAME_1=${CLUSTER_NAME}-privatesubnet-1-${ZONE}
    export FIREWALL_RULE_NAME_1=${CLUSTER_NAME}-privatefirewall-1-${ZONE}
    export ROUTER_NAME_1=${CLUSTER_NAME}-network-1-${ZONE}
    export NAT_CONFIG_1=${CLUSTER_NAME}-natconfig-1-${ZONE}
    
    export NETWORK_NAME_2=${CLUSTER_NAME}-mtu9k-2-${ZONE}
    export SUBNET_NAME_2=${CLUSTER_NAME}-privatesubnet-2-${ZONE}
    export FIREWALL_RULE_NAME_2=${CLUSTER_NAME}-privatefirewall-2-${ZONE}
    export ROUTER_NAME_2=${CLUSTER_NAME}-network-2-${ZONE}
    export NAT_CONFIG_2=${CLUSTER_NAME}-natconfig-2-${ZONE}

    請設定下列環境變數:

    • CLUSTER_NAME:叢集名稱。
    • REGION:要建立 TPU 叢集的區域。
    • ZONE:要建立 TPU 叢集的可用區。如要進一步瞭解支援的區域,請參閱「地區和區域」。
    • PROJECT_ID: Google Cloud 專案 ID。
    • ACCELERATOR_TYPE:加速器類型會指定您要建立的 Cloud TPU 版本和大小。例如,v6e-256。如要進一步瞭解各個 TPU 版本支援的加速器類型,請參閱「TPU 版本」。
    • NUM_SLICES:叢集的 TPU 節點數量。每個切片都有 ACCELERATOR_TYPE 中指定的晶片數量。如果是單一切片叢集,請將 NUM_SLICES 設為 1。如果是 Multislice 叢集,請根據工作負載的可擴充性需求指定切片數量。叢集中的晶片總數是 ACCELERATOR_TYPE 中的晶片數乘以 NUM_SLICES

建立主要網路資源

  1. 建立最大傳輸單位 (MTU) 為 8,896 的主要網路:

    gcloud compute networks create ${NETWORK_NAME_1} \
        --mtu=8896 \
        --bgp-routing-mode=regional \
        --subnet-mode=custom \
        --project=${PROJECT_ID}

    使用 MTU 為 8,896 的自訂網路可提升效能。詳情請參閱「設定 MTU」。

  2. 建立主要子網路:

    gcloud compute networks subnets create ${SUBNET_NAME_1} \
        --network=${NETWORK_NAME_1} \
        --range=10.11.0.0/18 \
        --region=${REGION} \
        --project=${PROJECT_ID}
  3. 為主要網路建立防火牆規則,允許主要網路上的 tcpicmpudp 流量:

    gcloud compute firewall-rules create ${FIREWALL_RULE_NAME_1} \
        --network=${NETWORK_NAME_1} \
        --allow tcp,icmp,udp \
        --project=${PROJECT_ID}
  4. 為主要網路建立 Cloud Router:

    gcloud compute routers create ${ROUTER_NAME_1} \
        --project=${PROJECT_ID} \
        --network=${NETWORK_NAME_1} \
        --region=${REGION}
  5. 為主要網路設定 NAT。下列指令可讓叢集的流量連上網際網路:

    gcloud compute routers nats create ${NAT_CONFIG_1} \
        --router=${ROUTER_NAME_1} \
        --region=${REGION} \
        --auto-allocate-nat-external-ips \
        --nat-all-subnet-ip-ranges \
        --project=${PROJECT_ID} \
        --enable-logging

建立次要網路資源

  1. 建立次要網路:

    gcloud compute networks create ${NETWORK_NAME_2} --mtu=8896 \
        --bgp-routing-mode=regional \
        --subnet-mode=custom \
        --project=${PROJECT_ID}
    
  2. 為次要網路建立子網路:

    gcloud compute networks subnets create ${SUBNET_NAME_2} \
        --network=${NETWORK_NAME_2} \
        --range=10.10.0.0/18 \
        --region=${REGION} \
        --project=${PROJECT_ID}
    
  3. 建立防火牆規則,允許新網路中的流量:

    gcloud compute firewall-rules create ${FIREWALL_RULE_NAME_2} \
        --network=${NETWORK_NAME_2} \
        --allow tcp,icmp,udp \
        --source-ranges 10.10.0.0/18 \
        --project=${PROJECT_ID}
    
  4. 為次要網路建立 Cloud Router:

    gcloud compute routers create ${ROUTER_NAME_2} \
        --project=${PROJECT_ID} \
        --network=${NETWORK_NAME_2} \
        --region=${REGION}
    
  5. 為 Cloud Router 建立 NAT 設定:

    gcloud compute routers nats create ${NAT_CONFIG_2} \
        --router=${ROUTER_NAME_2} \
        --region=${REGION} \
        --auto-allocate-nat-external-ips \
        --nat-all-subnet-ip-ranges \
        --project=${PROJECT_ID} \
        --enable-logging
    

建立叢集

  1. 為叢集和節點集區引數設定環境變數,以使用您建立的網路和子網路:

    export CLUSTER_ARGUMENTS="--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}"
    export NODE_POOL_ARGUMENTS="--additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}"
    

    這些引數會將叢集設定為使用您建立的兩個網路,以支援多重 NIC。

  2. 建立叢集。下列指令會佈建隨需容量:

    xpk cluster create \
        --cluster=${CLUSTER_NAME} \
        --cluster-cpu-machine-type=e2-standard-8 \
        --num-slices=${NUM_SLICES} \
        --tpu-type=${ACCELERATOR_TYPE} \
        --zone=${ZONE}  \
        --project=${PROJECT_ID} \
        --on-demand \
        --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \
        --custom-nodepool-arguments="${NODE_POOL_ARGUMENTS}" \
        --create-vertex-tensorboard

    如要使用預留容量,請將 --on-demand 替換為 --reservation=RESERVATION_NAME。如要使用 TPU Spot VM,請將 --on-demand 替換為 --spot

驗證多 NIC 設定

建立支援多個 NIC 的叢集後,您可以建立 XPK 工作負載並新增 --command ifconfig 旗標,驗證兩個 NIC 是否都正在使用中。

  1. 使用下列指令,在Google Cloud 控制台記錄中顯示 ifconfig 指令的輸出內容。您必須指定 --base-docker-image maxtext_base_image 旗標來使用 MaxText 基礎映像檔 (如下列範例所示),或指定 --docker-image 旗標和要使用的映像檔。

    xpk workload create \
        --cluster ${CLUSTER_NAME} \
        --base-docker-image maxtext_base_image \
        --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
        --tpu-type=${ACCELERATOR_TYPE} \
        --num-slices=${NUM_SLICES}  \
        --on-demand \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        --command "ifconfig"

    如要啟用偵錯記錄或使用 Vertex AI TensorBoard,請在指令中加入下列選用引數:

    --enable-debug-logs \
    --use-vertex-tensorboard
  2. 在 Google Cloud 控制台記錄中檢查 XPK 工作負載的輸出內容,確認 eth0 和 eth1 的 MTU 都設為 8,896。

設定 JAX 或 PyTorch

下列資源說明如何在 TPU 上設定 JAX 或 PyTorch,具體做法取決於您使用的佈建和管理方法:

如要使用 MaxText 設定及執行 XPK,請參閱「使用 XPK 大規模執行 MaxText 」。

改善 TCP 設定

如果您使用佇列資源佈建 v6e TPU,可以執行下列指令,提高 TCP 接收緩衝區限制,進而提升網路效能。

gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \
    --project "${PROJECT_ID}" \
    --zone "${ZONE}" \
    --node=all \
    --worker=all \
    --command='
    sudo sh -c "echo \"4096 41943040 314572800\" > /proc/sys/net/ipv4/tcp_rmem"'

使用 SkyPilot

您可以搭配 SkyPilot 使用 Cloud TPU v6e。SkyPilot 是開放原始碼架構,可簡化執行、管理及調度 AI 工作負載的程序。你可以在 SkyPilot 中新增 v6e 相關位置和定價資訊。 詳情請參閱 SkyPilot TPU v6e 範例

訓練範例

下列各節提供在 Cloud TPU v6e 上訓練 MaxText、MaxDiffusion 和 PyTorch 模型的範例。

這些範例已使用下列軟體版本測試:

  • Python 3.10 以上版本
  • 夜間軟體版本:
    • 每晚 JAX 0.4.32.dev20240912
    • 每晚 LibTPU 0.1.dev20240912+nightly
  • 穩定版軟體:
    • JAX + JAX Lib of v0.4.37

在 Cloud TPU v6e 上訓練 MaxText 和 MaxDiffusion

以下各節將說明 MaxTextMaxDiffusion 模型的訓練生命週期。

一般來說,高階步驟如下:

  1. 建構工作負載基本映像檔。
  2. 使用 XPK 執行工作負載。
    1. 為工作負載建構訓練指令。
    2. 部署工作負載。
  3. 追蹤工作負載並查看指標。
  4. 如不需要,請刪除 XPK 工作負載。
  5. 不再需要叢集時,請將其刪除。

建構基本映像檔

安裝 MaxText 或 MaxDiffusion,然後建構 Docker 映像檔:

  1. 複製要使用的存放區,然後變更為存放區的目錄:

    MaxText:

    git clone https://github.com/google/maxtext.git && cd maxtext
    

    MaxDiffusion:

    git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103
    
  2. 將 Docker 設為使用 Google Cloud CLI:

    gcloud auth configure-docker
    
  3. 使用下列指令或 JAX AI 圖片建構 Docker 映像檔。 如要進一步瞭解 JAX AI 圖片,請參閱「JAX AI 圖片」。

    MaxText:

    bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.35
    

    MaxDiffusion:

    bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack MODE=jax_ai_image PROJECT=${PROJECT_ID} LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest
    
  4. 在有效的 gcloud CLI 設定中設定專案 ID:

    gcloud config set project ${PROJECT_ID}
    
  5. 如果從沒有在本機建構映像檔的機器啟動工作負載,請上傳映像檔。

    1. 設定 CLOUD_IMAGE_NAME 環境變數:

      export CLOUD_IMAGE_NAME=${USER}_runner
      
    2. 上傳圖片:

      bash docker_upload_runner.sh ${CLOUD_IMAGE_NAME}
      

使用 XPK 執行工作負載

  1. 如未使用 MaxText 設定的預設值MaxDiffusion,請設定下列環境變數:

    export BASE_OUTPUT_DIR=gs://YOUR_BUCKET
    export PER_DEVICE_BATCH_SIZE=2
    export NUM_STEPS=30
    export MAX_TARGET_LENGTH=8192
  2. 建構模型指令碼。這個指令碼會在後續步驟中複製為訓練指令。

    請先不要執行模型指令碼。

    MaxText

    MaxText 是以純 Python 和 JAX 編寫的開放原始碼 LLM,具備高效能和高擴充性,適用於 Google Cloud TPU 和 GPU,可進行訓練和推論。

    JAX_PLATFORMS=tpu,cpu \
    ENABLE_PJRT_COMPATIBILITY=true \
    TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \
    TPU_SLICE_BUILDER_DUMP_ICI=true && \
    python3 -m MaxText.train MaxText/configs/base.yml \
         base_output_directory=${BASE_OUTPUT_DIR} \
         dataset_type=synthetic \
         per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
         enable_checkpointing=false \
         gcs_metrics=true \
         profiler=xplane \
         skip_first_n_steps_for_profiler=5 \
         steps=${NUM_STEPS}  # attention='dot_product'"
    

    Gemma2

    Gemma 是 Google DeepMind 開發的一系列開放權重 LLM,以 Gemini 研究和技術為基礎。

    python3 -m MaxText.train MaxText/configs/base.yml \
        model_name=gemma2-27b \
        run_name=gemma2-27b-run \
        base_output_directory=${BASE_OUTPUT_DIR} \
        max_target_length=${MAX_TARGET_LENGTH} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        steps=${NUM_STEPS} \
        enable_checkpointing=false \
        use_iota_embed=true \
        gcs_metrics=true \
        dataset_type=synthetic \
        profiler=xplane \
        attention=flash
    

    Mixtral 8x7b

    Mixtral 是由 Mistral AI 開發的頂尖 AI 模型,採用稀疏混合專家 (MoE) 架構。

    python3 -m MaxText.train MaxText/configs/base.yml \
        base_output_directory=${BASE_OUTPUT_DIR} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        model_name=mixtral-8x7b \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        tokenizer_path=assets/tokenizer.mistral-v1 \
        attention=flash \
        dtype=bfloat16 \
        dataset_type=synthetic \
        profiler=xplane
    

    Llama3-8b

    Llama 是 Meta 開發的一系列開放權重 LLM。

    如需瞭解如何在 PyTorch 上執行 Llama3,請參閱 torchprime 存放區中的 torch_xla 模型

    MaxDiffusion

    MaxDiffusion 是一系列以純 Python 和 JAX 編寫的各種延遲擴散模型參考實作,可在 XLA 裝置上執行,包括 Cloud TPU 和 GPU。Stable Diffusion 是潛在文字轉圖像模型,可根據任何文字輸入生成逼真的圖像。

    您需要安裝特定 Git 分支,才能執行 MaxDiffusion,如下列訓練指令碼所示。

    git clone https://github.com/google/maxdiffusion.git
    && cd maxdiffusion
    && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103
    && pip install -r requirements.txt && pip install .
    && pip install huggingface_hub==0.30.2 && OUT_DIR=${BASE_OUTPUT_DIR}
    && python src/maxdiffusion/train_sdxl.py \
        src/maxdiffusion/configs/base_xl.yml \
        revision=refs/pr/95 \
        activations_dtype=bfloat16 \
        weights_dtype=bfloat16 \
        resolution=1024 \
        per_device_batch_size=1 \
        output_dir=${OUT_DIR} \
        jax_cache_dir=${OUT_DIR}/cache_dir/ \
        max_train_steps=200 \
        attention=flash \
        run_name=sdxl-ddp-v6e
    
  3. 匯出下列變數:

    export CLUSTER_NAME=CLUSTER_NAME
    export ACCELERATOR_TYPE=ACCELERATOR_TYPE
    export NUM_SLICES=NUM_SLICES
    export YOUR_MODEL_SCRIPT=YOUR_MODEL_SCRIPT

    環境變數說明

    • CLUSTER_NAME:叢集名稱。
    • ACCELERATOR_TYPE:加速器類型會指定您要建立的 Cloud TPU 版本和大小。如要進一步瞭解各個 TPU 版本支援的加速器類型,請參閱 TPU 版本
    • NUM_SLICES:TPU 配量數量。
    • YOUR_MODEL_SCRIPT:要以訓練指令執行的模型指令碼。
  4. 使用上一個步驟建立的指令碼執行模型。您必須指定 --base-docker-image 旗標來使用 MaxText 基本映像檔,或是指定 --docker-image 旗標和要使用的映像檔。

    您可以選擇新增下列選用旗標:

    xpk workload create \
      --cluster ${CLUSTER_NAME} \
      {--base-docker-image maxtext_base_image | --docker-image gcr.io/${PROJECT_ID}/${CLOUD_IMAGE_NAME}:latest} \
      --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
      --tpu-type=${ACCELERATOR_TYPE} \
      --num-slices=${NUM_SLICES}  \
      --on-demand \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --command="${YOUR_MODEL_SCRIPT}"

    輸出內容包含追蹤工作負載的連結。 開啟連結並點選「記錄」分頁,即可即時追蹤工作負載。

在 MaxText 上偵錯 JAX

使用補充 XPK 指令,診斷叢集或工作負載無法執行的原因:

使用 Vertex AI 監控 MaxText 上的 JAX

如要使用 TensorBoard,您的 Google Cloud 使用者帳戶必須具備aiplatform.user角色。執行下列指令來授予這個角色:

gcloud projects add-iam-policy-binding your-project-id \
   --member='user:your-email' \
   --role='roles/aiplatform.user'

透過 Vertex AI 管理的 TensorBoard 查看純量和設定檔資料。

  1. 將您使用的區域資源管理 (CRUD) 要求數從 600 提高至 5000。如果小型工作負載使用的 VM 少於 16 個,可能不會有問題。

  2. 安裝 Vertex AI 的依附元件,例如 cloud-accelerator-diagnostics

    # xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI
    cd ~/xpk
    pip install .
  3. 如「建立 Vertex AI TensorBoard」一文所述,使用 --create-vertex-tensorboard 旗標建立叢集。您也可以在現有叢集上執行這項指令。

  4. 使用 --use-vertex-tensorboard 旗標和選用的 --experiment-name 旗標執行 XPK 工作負載時,請建立 Vertex AI 實驗。如需完整步驟清單,請參閱「建立 Vertex AI 實驗,將資料上傳至 Vertex AI TensorBoard」。

記錄包含 Vertex AI TensorBoard 的連結,類似於下列連結:

View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name

您也可以在 Google Cloud 控制台中找到 Vertex AI TensorBoard 連結。 前往 Google Cloud 控制台中的 Vertex AI Experiments。從下拉式選單中選取適當的地區。

TensorBoard 目錄也會寫入您以 ${BASE_OUTPUT_DIR} 指定的 Cloud Storage bucket。

刪除 XPK 工作負載

使用 xpk workload delete 指令,根據工作前置字元或工作狀態刪除一或多個工作負載。如果您傳送了不再需要執行的 XPK 工作負載,或是工作停滯在佇列中,這個指令就非常實用。

刪除叢集

使用 xpk cluster delete 指令刪除叢集:

xpk cluster delete --cluster ${CLUSTER_NAME} \
    --zone=${ZONE} --project=${PROJECT_ID}

MaxDiffusion 基準化結果

我們在 v6e-4、v6e-16 和兩個 v6e-16 上執行 MaxDiffusion 的訓練指令碼。下表顯示測得的處理量。

v6e-4 v6e-16 兩個 v6e-16
訓練步驟 0.069 0.073 0.13
全域批次大小 8 32 64
處理量 (每秒範例數) 115.9 438.4 492.3

在 Cloud TPU v6e 上使用 PyTorch/XLA 訓練 Llama 模型

本節說明如何使用 PyTorch/XLA 在 Cloud TPU v6e 上,透過 WikiText 資料集訓練 Llama 模型。

存取 Hugging Face 和 Llama 3 模型

您需要 Hugging Face 使用者存取權杖才能執行這個範例。如要瞭解如何建立使用者存取權杖,請參閱 Hugging Face 使用者存取權杖說明文件

此外,您也需要取得 Hugging Face 上的 Llama-3-8B 模型存取權。如要取得存取權,請前往 HuggingFace 上的 Meta-Llama-3-8B 模型,然後要求存取權。

建立 Cloud TPU VM

在本範例中,請建立具有 8 個晶片的 Cloud TPU v6e。

  1. 設定環境變數:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-east1-d
    export ACCELERATOR_TYPE=v6e-8
    export RUNTIME_VERSION=v2-alpha-tpuv6e

    環境變數說明

    • PROJECT_ID: 您的 Google Cloud 專案 ID。使用現有專案或建立新專案
    • TPU_NAME:TPU 的名稱。
    • ZONE: 要建立 TPU VM 的可用區。如要進一步瞭解支援的區域,請參閱 TPU 地區和區域
    • ACCELERATOR_TYPE: 加速器類型會指定要建立的 Cloud TPU 版本和大小。如要進一步瞭解各個 TPU 版本支援的加速器類型,請參閱 TPU 版本
    • RUNTIME_VERSION:Cloud TPU 軟體版本

  2. 建立 Cloud TPU VM:

    gcloud alpha compute tpus tpu-vm create ${TPU_NAME} --version=${RUNTIME_VERSION} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --zone=${ZONE} \
       --project=${PROJECT_ID}

安裝

安裝 Hugging Face Transformers 的 pytorch-tpu/transformers 分支和依附元件。這個範例已使用下列依附元件版本進行測試:

  • torch:與 2.5.0 相容
  • torch_xla[tpu]:與 2.5.0 相容
  • jax:0.4.33
  • jaxlib:0.4.33
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone ${ZONE} \
   --worker=all \
   --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
   cd transformers
   sudo pip3 install -e .
   pip3 install datasets
   pip3 install evaluate
   pip3 install scikit-learn
   pip3 install accelerate
   pip install torch~=2.6.0 torch_xla[tpu]~=2.6.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
   pip install jax==0.4.38 jaxlib==0.4.38 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/'

設定模型設定檔

下一節「執行模型」中的訓練指令會使用兩個 JSON 設定檔,定義模型參數和完全分片資料平行 (FSDP) 設定。FSDP 分片功能可將模型權重分散到多個 TPU,讓您在訓練時使用較大的批次大小。使用較小的模型進行訓練時,可能只要使用資料平行處理,並在每部裝置上複製權重即可。如要進一步瞭解如何在 PyTorch/XLA 中跨裝置分片張量,請參閱 PyTorch/XLA SPMD 使用指南

  1. 建立模型參數設定檔。以下是 Llama-3-8B 的模型參數設定。如要使用其他模型,請在 Hugging Face 尋找設定檔。舉例來說,請參閱 Llama-2-7B 設定

    cat > llama-config.json << EOF
    {
      "architectures": [
        "LlamaForCausalLM"
      ],
      "attention_bias": false,
      "attention_dropout": 0.0,
      "bos_token_id": 128000,
      "eos_token_id": 128001,
      "hidden_act": "silu",
      "hidden_size": 4096,
      "initializer_range": 0.02,
      "intermediate_size": 14336,
      "max_position_embeddings": 8192,
      "model_type": "llama",
      "num_attention_heads": 32,
      "num_hidden_layers": 32,
      "num_key_value_heads": 8,
      "pretraining_tp": 1,
      "rms_norm_eps": 1e-05,
      "rope_scaling": null,
      "rope_theta": 500000.0,
      "tie_word_embeddings": false,
      "torch_dtype": "bfloat16",
      "transformers_version": "4.40.0.dev0",
      "use_cache": false,
      "vocab_size": 128256
    }
    EOF
    
  2. 建立 FSDP 設定檔:

    cat > fsdp-config.json << EOF
    {
      "fsdp_transformer_layer_cls_to_wrap": [
        "LlamaDecoderLayer"
      ],
      "xla": true,
      "xla_fsdp_v2": true,
      "xla_fsdp_grad_ckpt": true
    }
    EOF
    

    如要進一步瞭解 FSDP,請參閱「使用 SPMD 的完全分片資料平行處理 」。

  3. 使用下列指令,將設定檔上傳至 Cloud TPU VM:

    gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json ${TPU_NAME}:. \
       --worker=all \
       --project=${PROJECT_ID} \
       --zone=${ZONE}

執行模型

使用您在前一節建立的設定檔,執行 run_clm.py 指令碼,在 WikiText 資料集上訓練 Llama-3-8B 模型。訓練指令碼在 Cloud TPU v6e-8 上執行約需 10 分鐘。

  1. 在 Cloud TPU 上使用下列指令登入 Hugging Face:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone ${ZONE} \
       --worker=all \
       --command='
       pip3 install "huggingface_hub[cli]"
       huggingface-cli login --token HUGGING_FACE_TOKEN'
  2. 執行模型訓練:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone ${ZONE} \
       --worker=all \
       --command='
       export PJRT_DEVICE=TPU
       export XLA_USE_SPMD=1
       export ENABLE_PJRT_COMPATIBILITY=true
       # Optional variables for debugging:
       export XLA_IR_DEBUG=1
       export XLA_HLO_DEBUG=1
       export PROFILE_EPOCH=0
       export PROFILE_STEP=3
       export PROFILE_DURATION_MS=100000
       # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path
       export PROFILE_LOGDIR=PROFILE_PATH
       python3 transformers/examples/pytorch/language-modeling/run_clm.py \
         --dataset_name wikitext \
         --dataset_config_name wikitext-2-raw-v1 \
         --per_device_train_batch_size 16 \
         --do_train \
         --output_dir /home/$USER/tmp/test-clm \
         --overwrite_output_dir \
         --config_name /home/$USER/llama-config.json \
         --cache_dir /home/$USER/cache \
         --tokenizer_name meta-llama/Meta-Llama-3-8B \
         --block_size 8192 \
         --optim adafactor \
         --save_strategy no \
         --logging_strategy no \
         --fsdp "full_shard" \
         --fsdp_config /home/$USER/fsdp-config.json \
         --torch_dtype bfloat16 \
         --dataloader_drop_last yes \
         --flash_attention \
         --max_steps 20'

排解 PyTorch/XLA 問題

如果您在上一節中設定了用於偵錯的選用變數,模型的設定檔會儲存在變數 PROFILE_LOGDIR 指定的位置。您可以擷取儲存在這個位置的 xplane.pb 檔案,並使用 tensorboard 透過 TensorBoard 指示在瀏覽器中查看設定檔。

如果 PyTorch/XLA 未如預期運作,請參閱疑難排解指南,瞭解如何偵錯、分析及最佳化模型。