搭配使用 PyTorch 和 Cloud TPU 訓練 Resnet50

這個教學課程將說明如何使用 PyTorch 在 Cloud TPU 裝置中訓練 ResNet-50 模型。如果您有其他已針對 TPU 完成最佳化處理的圖片分類模型,而且這些模型使用的是 PyTorch 和 ImageNet 資料集,您也可以按照這個教學課程中的步驟對其進行訓練。

本教學課程中的模型是以圖像識別的深度殘差學習為基礎,該論文首度提出殘差網路 (ResNet) 架構的概念。這個教學課程使用了含有 50 層架構的變化版本「ResNet-50」,並說明如何使用 PyTorch/XLA 訓練模型。

目標

  • 準備資料集。
  • 執行訓練工作。
  • 驗證輸出結果。

費用

在本文件中,您會使用下列 Google Cloud的計費元件:

  • Compute Engine
  • Cloud TPU

您可以使用 Pricing Calculator,根據預測用量估算費用。

初次使用 Google Cloud 的使用者可能符合免費試用期資格。

事前準備

開始學習這個教學課程之前,請先檢查 Google Cloud 專案設定是否正確。

  1. Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
  2. In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Roles required to select or create a project

    • Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
    • Create a project: To create a project, you need the Project Creator role (roles/resourcemanager.projectCreator), which contains the resourcemanager.projects.create permission. Learn how to grant roles.

    Go to project selector

  3. Verify that billing is enabled for your Google Cloud project.

  4. In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Roles required to select or create a project

    • Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
    • Create a project: To create a project, you need the Project Creator role (roles/resourcemanager.projectCreator), which contains the resourcemanager.projects.create permission. Learn how to grant roles.

    Go to project selector

  5. Verify that billing is enabled for your Google Cloud project.

  6. 本逐步操作說明使用 Google Cloud的計費元件,請參閱 Cloud TPU 定價頁面來估算費用。使用完畢後,請務必清除您建立的資源,以免產生不必要的費用。

建立 TPU VM

  1. 開啟 Cloud Shell 視窗。

    開啟 Cloud Shell

  2. 建立 TPU VM

    gcloud compute tpus tpu-vm create your-tpu-name \
    --accelerator-type=v3-8 \
    --version=tpu-ubuntu2204-base \
    --zone=us-central1-a \
    --project=your-project
  3. 使用 SSH 連線至 TPU VM:

    gcloud compute tpus tpu-vm ssh  your-tpu-name --zone=us-central1-a
  4. 在 TPU VM 上安裝 PyTorch/XLA:

    (vm)$ pip install torch torch_xla[tpu] torchvision -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
  5. 複製 PyTorch/XLA GitHub 存放區

    (vm)$ git clone --depth=1 https://github.com/pytorch/xla.git
  6. 使用偽造資料執行訓練指令碼

    (vm) $ PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1

清除所用資源

為避免因為本教學課程所用資源,導致系統向 Google Cloud 收取費用,請刪除含有相關資源的專案,或者保留專案但刪除個別資源。

  1. 中斷與 TPU VM 的連線:

    (vm) $ exit

    系統現在顯示的提示訊息應為 username@projectname,代表您位於 Cloud Shell。

  2. 刪除 TPU VM。

    $ gcloud compute tpus tpu-vm delete your-tpu-name \
       --zone=us-central1-a

後續步驟