在 TPU 配量上執行 JAX 程式碼
執行本文中的指令前,請務必按照「設定帳戶和 Cloud TPU 專案」一文中的操作說明進行設定。
在單一 TPU 板上執行 JAX 程式碼後,您可以在 TPU 配量上執行程式碼,擴大規模。TPU 配量是指透過專用高速網路連線相互連結的多個 TPU 板。本文簡要介紹如何在 TPU 配量上執行 JAX 程式碼,如需更深入的資訊,請參閱「在多主機和多程序環境中使用 JAX」。
必要的角色
如要取得建立 TPU 並透過 SSH 連線所需的權限,請要求管理員在專案中授予您下列 IAM 角色:
如要進一步瞭解如何授予角色,請參閱「管理專案、資料夾和組織的存取權」。
建立 Cloud TPU 節點
建立一些環境變數:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=europe-west4-b export ACCELERATOR_TYPE=v5litepod-32 export RUNTIME_VERSION=v2-alpha-tpuv5-lite
使用
gcloud指令建立 TPU 節點。舉例來說,如要建立 v5litepod-32 配量,請使用下列指令:$ gcloud compute tpus tpu-vm create ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --accelerator-type=${ACCELERATOR_TYPE} \ --version=${RUNTIME_VERSION}
在切片上安裝 JAX
建立 TPU 配量後,您必須在 TPU 配量中的所有主機上安裝 JAX。如要執行這項操作,請使用 gcloud compute tpus tpu-vm ssh 指令,並搭配 --worker=all 和 --commamnd 參數。
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --worker=all \ --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
在配量上執行 JAX 程式碼
如要在 TPU 配量上執行 JAX 程式碼,您必須在 TPU 配量中的每個主機上執行程式碼。jax.device_count() 呼叫會停止回應,直到在切片中的每個主機上呼叫為止。以下範例說明如何在 TPU 節點上執行 JAX 計算。
準備程式碼
您需要 gcloud 344.0.0 以上版本 (適用於 scp 指令)。使用 gcloud --version 檢查 gcloud 版本,並視需要執行 gcloud components upgrade。
建立名為 example.py 的檔案,並加入下列程式碼:
import jax
# Initialize the slice
jax.distributed.initialize()
# The total number of TPU cores in the slice
device_count = jax.device_count()
# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()
# The psum is performed over all mapped devices across the slice
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)
# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
print('global device count:', jax.device_count())
print('local device count:', jax.local_device_count())
print('pmap result:', r)
將 example.py 複製到配量中的所有 TPU 工作站 VM
$ gcloud compute tpus tpu-vm scp ./example.py ${TPU_NAME}: \ --worker=all \ --zone=${ZONE} \ --project=${PROJECT_ID}
如果您先前未使用 scp 指令,可能會看到類似下列內容的錯誤訊息:
ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH
agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try
again.
如要解決錯誤,請執行錯誤訊息中顯示的 ssh-add 指令,然後重新執行指令。
在切片上執行程式碼
在每個 VM 上啟動 example.py 程式:
$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --worker=all \ --command="python3 ./example.py"
輸出內容 (使用 v5litepod-32 切片產生):
global device count: 32
local device count: 4
pmap result: [32. 32. 32. 32.]
清除所用資源
使用完 TPU VM 後,請按照下列步驟清除資源。
刪除 Cloud TPU 和 Compute Engine 資源。
$ gcloud compute tpus tpu-vm delete ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID}
執行
gcloud compute tpus execution-groups list,確認資源已刪除。刪除作業可能需要幾分鐘才能完成。下列指令的輸出內容不應包含本教學課程中建立的任何資源:$ gcloud compute tpus tpu-vm list --zone=${ZONE} \ --project=${PROJECT_ID}