在 TPU 配量上執行 PyTorch 程式碼
執行本文中的指令前,請務必按照「設定帳戶和 Cloud TPU 專案」一文中的操作說明進行設定。
在單一 TPU VM 上執行 PyTorch 程式碼後,您可以在 TPU 配量上執行程式碼,進一步擴大規模。TPU 配量是指透過專用高速網路連線相互連結的多個 TPU 板。本文將介紹如何在 TPU 配量上執行 PyTorch 程式碼。
建立 Cloud TPU 節點
定義一些環境變數,方便使用指令。
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=europe-west4-b export ACCELERATOR_TYPE=v5p-32 export RUNTIME_VERSION=v2-alpha-tpuv5
執行下列指令,建立 TPU VM:
$ gcloud compute tpus tpu-vm create ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --accelerator-type=${ACCELERATOR_TYPE} \ --version=${RUNTIME_VERSION}
在 Slice 上安裝 PyTorch/XLA
建立 TPU 配量後,您必須在 TPU 配量中的所有主機上安裝 PyTorch。如要執行這項操作,請使用 gcloud compute tpus tpu-vm ssh 指令,並搭配 --worker=all 和 --commamnd 參數。
如果下列指令因 SSH 連線錯誤而失敗,可能是因為 TPU VM 沒有外部 IP 位址。如要存取沒有外部 IP 位址的 TPU VM,請按照「連線至沒有公開 IP 位址的 TPU VM」一文中的操作說明進行。
在所有 TPU VM 工作站上安裝 PyTorch/XLA:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --worker=all \ --command="pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html"
在所有 TPU VM 工作站上複製 XLA:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --worker=all \ --command="git clone https://github.com/pytorch/xla.git"
在 TPU 節點上執行訓練指令碼
在所有工作人員上執行訓練指令碼。訓練指令碼會使用單一程式多重資料 (SPMD) 分片策略。如要進一步瞭解 SPMD,請參閱 PyTorch/XLA SPMD 使用指南。
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --worker=all \ --command="PJRT_DEVICE=TPU python3 ~/xla/test/spmd/test_train_spmd_imagenet.py \ --fake_data \ --model=resnet50 \ --num_epochs=1 2>&1 | tee ~/logs.txt"
訓練過程約需 15 分鐘,完成後,您應該會看到類似下方的訊息:
Epoch 1 test end 23:49:15, Accuracy=100.00
10.164.0.11 [0] Max Accuracy: 100.00%
清除所用資源
使用完 TPU VM 後,請按照下列步驟清除資源。
如果尚未中斷與 Cloud TPU 執行個體的連線,請中斷連線:
(vm)$ exit
系統現在顯示的提示訊息應為
username@projectname,代表您位於 Cloud Shell。刪除 Cloud TPU 資源。
$ gcloud compute tpus tpu-vm delete \ --zone=${ZONE}
執行
gcloud compute tpus tpu-vm list,確認資源已刪除。刪除作業可能需要幾分鐘才能完成。下列指令的輸出內容不應包含本教學課程中建立的任何資源:$ gcloud compute tpus tpu-vm list --zone=${ZONE}