使用 PyTorch 在 Cloud TPU 上训练 Resnet50

本教程介绍如何使用 PyTorch 在 Cloud TPU 设备上训练 ResNet-50 模型。您可以将同一模式应用于使用 PyTorch 和 ImageNet 数据集的其他针对 TPU 进行了优化的图片分类模型。

本教程中的模型基于用于图片识别的深度残差学习,率先引入了残差网络 (ResNet) 架构。本教程使用 50 层变体 ResNet-50,演示如何使用 PyTorch/XLA 训练模型。

创建 TPU 虚拟机

  1. 打开一个 Cloud Shell 窗口。

    打开 Cloud Shell

  2. 创建 TPU 虚拟机

    gcloud compute tpus tpu-vm create your-tpu-name \
    --accelerator-type=v3-8 \
    --version=tpu-ubuntu2204-base \
    --zone=us-central1-a \
    --project=your-project
  3. 使用 SSH 连接到 TPU 虚拟机:

    gcloud compute tpus tpu-vm ssh  your-tpu-name --zone=us-central1-a
  4. 在 TPU 虚拟机上安装 PyTorch/XLA:

    (vm)$ pip install torch torch_xla[tpu] torchvision -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
  5. 克隆 PyTorch/XLA GitHub 代码库

    (vm)$ git clone --depth=1 https://github.com/pytorch/xla.git
  6. 使用虚构数据运行训练脚本

    (vm) $ PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1