在 TPU 配量上執行 JAX 程式碼

執行本文中的指令前,請務必按照「設定帳戶和 Cloud TPU 專案」一文中的操作說明進行設定。

在單一 TPU 板上執行 JAX 程式碼後,您可以在 TPU 配量上執行程式碼,擴大規模。TPU 配量是指透過專用高速網路連線相互連結的多個 TPU 板。本文簡要介紹如何在 TPU 配量上執行 JAX 程式碼,如需更深入的資訊,請參閱「在多主機和多程序環境中使用 JAX」。

必要的角色

如要取得建立 TPU 並透過 SSH 連線所需的權限,請要求管理員在專案中授予您下列 IAM 角色:

如要進一步瞭解如何授予角色,請參閱「管理專案、資料夾和組織的存取權」。

您或許也能透過自訂角色或其他預先定義的角色,取得必要權限。

建立 Cloud TPU 節點

  1. 建立一些環境變數:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=europe-west4-b
    export ACCELERATOR_TYPE=v5litepod-32
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite

    環境變數說明

    變數 說明
    PROJECT_ID 您的 Google Cloud 專案 ID。使用現有專案或建立新專案
    TPU_NAME TPU 的名稱。
    ZONE 要建立 TPU VM 的區域。如要進一步瞭解支援的區域,請參閱 TPU 區域和區域
    ACCELERATOR_TYPE 加速器類型會指定您要建立的 Cloud TPU 版本和大小。如要進一步瞭解各個 TPU 版本支援的加速器類型,請參閱「TPU 版本」。
    RUNTIME_VERSION Cloud TPU 軟體版本

  2. 使用 gcloud 指令建立 TPU 節點。舉例來說,如要建立 v5litepod-32 配量,請使用下列指令:

    $ gcloud compute tpus tpu-vm create ${TPU_NAME}  \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        --accelerator-type=${ACCELERATOR_TYPE}  \
        --version=${RUNTIME_VERSION} 

在切片上安裝 JAX

建立 TPU 配量後,您必須在 TPU 配量中的所有主機上安裝 JAX。如要執行這項操作,請使用 gcloud compute tpus tpu-vm ssh 指令,並搭配 --worker=all--commamnd 參數。

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --worker=all \
  --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

在配量上執行 JAX 程式碼

如要在 TPU 配量上執行 JAX 程式碼,您必須在 TPU 配量中的每個主機上執行程式碼jax.device_count() 呼叫會停止回應,直到在切片中的每個主機上呼叫為止。以下範例說明如何在 TPU 節點上執行 JAX 計算。

準備程式碼

您需要 gcloud 344.0.0 以上版本 (適用於 scp 指令)。使用 gcloud --version 檢查 gcloud 版本,並視需要執行 gcloud components upgrade

建立名為 example.py 的檔案,並加入下列程式碼:


import jax

# Initialize the slice
jax.distributed.initialize()

# The total number of TPU cores in the slice
device_count = jax.device_count()

# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()

# The psum is performed over all mapped devices across the slice
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
    print('global device count:', jax.device_count())
    print('local device count:', jax.local_device_count())
    print('pmap result:', r)

example.py 複製到配量中的所有 TPU 工作站 VM

$ gcloud compute tpus tpu-vm scp ./example.py ${TPU_NAME}: \
  --worker=all \
  --zone=${ZONE} \
  --project=${PROJECT_ID}

如果您先前未使用 scp 指令,可能會看到類似下列內容的錯誤訊息:

ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH
agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try
again.

如要解決錯誤,請執行錯誤訊息中顯示的 ssh-add 指令,然後重新執行指令。

在切片上執行程式碼

在每個 VM 上啟動 example.py 程式:

$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --worker=all \
  --command="python3 ./example.py"

輸出內容 (使用 v5litepod-32 切片產生):

global device count: 32
local device count: 4
pmap result: [32. 32. 32. 32.]

清除所用資源

使用完 TPU VM 後,請按照下列步驟清除資源。

  1. 刪除 Cloud TPU 和 Compute Engine 資源。

    $ gcloud compute tpus tpu-vm delete ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID}
  2. 執行 gcloud compute tpus execution-groups list,確認資源已刪除。刪除作業可能需要幾分鐘才能完成。下列指令的輸出內容不應包含本教學課程中建立的任何資源:

    $ gcloud compute tpus tpu-vm list --zone=${ZONE} \
    --project=${PROJECT_ID}