搭配使用 PyTorch 和 Cloud TPU 訓練 Resnet50

這個教學課程將說明如何使用 PyTorch 在 Cloud TPU 裝置中訓練 ResNet-50 模型。如果您有其他已針對 TPU 完成最佳化處理的圖片分類模型,而且這些模型使用的是 PyTorch 和 ImageNet 資料集,您也可以按照這個教學課程中的步驟對其進行訓練。

本教學課程中的模型是以圖像識別的深度殘差學習為基礎,該論文首度提出殘差網路 (ResNet) 架構的概念。這個教學課程使用了含有 50 層架構的變化版本「ResNet-50」,並說明如何使用 PyTorch/XLA 訓練模型。

建立 TPU VM

  1. 開啟 Cloud Shell 視窗。

    開啟 Cloud Shell

  2. 建立 TPU VM

    gcloud compute tpus tpu-vm create your-tpu-name \
    --accelerator-type=v3-8 \
    --version=tpu-ubuntu2204-base \
    --zone=us-central1-a \
    --project=your-project
  3. 使用 SSH 連線至 TPU VM:

    gcloud compute tpus tpu-vm ssh  your-tpu-name --zone=us-central1-a
  4. 在 TPU VM 上安裝 PyTorch/XLA:

    (vm)$ pip install torch torch_xla[tpu] torchvision -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
  5. 複製 PyTorch/XLA GitHub 存放區

    (vm)$ git clone --depth=1 https://github.com/pytorch/xla.git
  6. 使用偽造資料執行訓練指令碼

    (vm) $ PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1