在 TPU 切片上运行 PyTorch 代码
在运行本文档中的命令之前,请确保已按照设置账号和 Cloud TPU 项目中的说明操作。
在单个 TPU 虚拟机上运行 PyTorch 代码后,您可以通过在 TPU 切片上运行代码来扩容代码。TPU 切片是通过专用高速网络连接相互连接的多个 TPU 板。本文档介绍了如何在 TPU 切片上运行 PyTorch 代码。
所需的角色
如需获得创建 TPU 并使用 SSH 连接到该 TPU 所需的权限,请让您的管理员为您授予项目的以下 IAM 角色:
-
TPU Admin (
roles/tpu.admin) -
Service Account User (
roles/iam.serviceAccountUser) -
Compute Viewer (
roles/compute.viewer)
如需详细了解如何授予角色,请参阅管理对项目、文件夹和组织的访问权限。
创建 Cloud TPU 切片
定义一些环境变量,以便更轻松地使用命令。
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=europe-west4-b export ACCELERATOR_TYPE=v5p-32 export RUNTIME_VERSION=v2-alpha-tpuv5
环境变量说明
变量 说明 PROJECT_ID您的 Google Cloud 项目 ID。使用现有项目或创建新项目。 TPU_NAMETPU 的名称。 ZONE要在其中创建 TPU 虚拟机的可用区。如需详细了解支持的可用区,请参阅 TPU 区域和可用区。 ACCELERATOR_TYPE加速器类型用于指定您要创建的 Cloud TPU 的版本和大小。如需详细了解每个 TPU 版本支持的加速器类型,请参阅 TPU 版本。 RUNTIME_VERSIONCloud TPU 软件版本。 通过运行以下命令创建 TPU 虚拟机:
$ gcloud compute tpus tpu-vm create ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --accelerator-type=${ACCELERATOR_TYPE} \ --version=${RUNTIME_VERSION}
在切片上安装 PyTorch/XLA
创建 TPU 切片后,您必须在 TPU 切片中的所有主机上安装 PyTorch。您可以使用 gcloud compute tpus tpu-vm ssh 命令以及 --worker=all 和 --command 参数来执行此操作。
如果以下命令因 SSH 连接错误而失败,可能是因为 TPU 虚拟机没有外部 IP 地址。如需访问没有外部 IP 地址的 TPU 虚拟机,请按照连接到没有公共 IP 地址的 TPU 虚拟机中的说明操作。
在所有 TPU 虚拟机工作器上安装 PyTorch/XLA:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --worker=all \ --command="pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html"
在所有 TPU 虚拟机工作器上克隆 XLA:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --worker=all \ --command="git clone https://github.com/pytorch/xla.git"
在 TPU 切片上运行训练脚本
在所有工作器上运行训练脚本。训练脚本使用单程序多数据 (SPMD) 分片策略。如需详细了解 SPMD,请参阅 PyTorch/XLA SPMD 用户指南。
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --worker=all \ --command="PJRT_DEVICE=TPU python3 ~/xla/test/spmd/test_train_spmd_imagenet.py \ --fake_data \ --model=resnet50 \ --num_epochs=1 2>&1 | tee ~/logs.txt"
训练大约需要 15 分钟。完成后,您应该会看到如下所示的消息:
Epoch 1 test end 23:49:15, Accuracy=100.00
10.164.0.11 [0] Max Accuracy: 100.00%
清理
完成 TPU 虚拟机的操作后,请按照以下步骤清理资源。
断开与 Cloud TPU 实例的连接(如果您尚未这样做):
(vm)$ exit
您的提示符现在应为
username@projectname,表明您位于 Cloud Shell 中。删除 Cloud TPU 资源。
$ gcloud compute tpus tpu-vm delete \ --zone=${ZONE}
通过运行
gcloud compute tpus tpu-vm list验证资源是否已删除。删除操作可能需要几分钟时间才能完成。以下命令的输出不应包含本教程中创建的任何资源:$ gcloud compute tpus tpu-vm list --zone=${ZONE}