使用 PyTorch 在 Cloud TPU VM 上執行計算

本文簡要說明如何使用 PyTorch 和 Cloud TPU。

事前準備

執行本文中的指令前,請先建立 Google Cloud 帳戶、安裝 Google Cloud CLI,並設定 gcloud 指令。詳情請參閱「設定 Cloud TPU 環境」。

必要的角色

如要取得建立 TPU 並透過 SSH 連線所需的權限,請要求管理員在專案中授予您下列 IAM 角色:

如要進一步瞭解如何授予角色,請參閱「管理專案、資料夾和組織的存取權」。

您或許也能透過自訂角色或其他預先定義的角色,取得必要權限。

使用 gcloud 建立 Cloud TPU

  1. 定義一些環境變數,方便使用指令。

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-east5-a
    export ACCELERATOR_TYPE=v5litepod-8
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite

    環境變數說明

    變數 說明
    PROJECT_ID 您的 Google Cloud 專案 ID。使用現有專案或建立新專案
    TPU_NAME TPU 的名稱。
    ZONE 要建立 TPU VM 的區域。如要進一步瞭解支援的區域,請參閱 TPU 區域和區域
    ACCELERATOR_TYPE 加速器類型會指定您要建立的 Cloud TPU 版本和大小。如要進一步瞭解各個 TPU 版本支援的加速器類型,請參閱「TPU 版本」。
    RUNTIME_VERSION Cloud TPU 軟體版本

  2. 執行下列指令,建立 TPU VM:

    $ gcloud compute tpus tpu-vm create $TPU_NAME \
        --project=$PROJECT_ID \
        --zone=$ZONE \
        --accelerator-type=$ACCELERATOR_TYPE \
        --version=$RUNTIME_VERSION

連線至 Cloud TPU VM

使用下列指令,透過 SSH 連線至 TPU VM:

$ gcloud compute tpus tpu-vm ssh $TPU_NAME \
    --project=$PROJECT_ID \
    --zone=$ZONE

如果無法使用 SSH 連線至 TPU VM,可能是因為 TPU VM 沒有外部 IP 位址。如要存取沒有外部 IP 位址的 TPU VM,請按照「連線至沒有公開 IP 位址的 TPU VM」一文中的操作說明進行。

在 TPU VM 上安裝 PyTorch/XLA

$ (vm) sudo apt-get update
$ (vm) sudo apt-get install libopenblas-dev -y
$ (vm) pip install numpy
$ (vm) pip install torch torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html

確認 PyTorch 可以存取 TPU

使用下列指令確認 PyTorch 可以存取 TPU:

$ (vm) PJRT_DEVICE=TPU python3 -c "import torch_xla.core.xla_model as xm; print(xm.get_xla_supported_devices(\"TPU\"))"

指令輸出內容應如下所示:

['xla:0', 'xla:1', 'xla:2', 'xla:3', 'xla:4', 'xla:5', 'xla:6', 'xla:7']

執行基本計算

  1. 在目前目錄中建立名為 tpu-test.py 的檔案,然後將下列指令碼複製並貼到檔案中:

    import torch
    import torch_xla.core.xla_model as xm
    
    dev = xm.xla_device()
    t1 = torch.randn(3,3,device=dev)
    t2 = torch.randn(3,3,device=dev)
    print(t1 + t2)
    
  2. 執行指令碼:

    (vm)$ PJRT_DEVICE=TPU python3 tpu-test.py

    指令碼輸出內容會顯示運算結果:

    tensor([[-0.2121,  1.5589, -0.6951],
            [-0.7886, -0.2022,  0.9242],
            [ 0.8555, -1.8698,  1.4333]], device='xla:1')
    

清除所用資源

為了避免系統向您的 Google Cloud 帳戶收取本頁面所用資源的費用,請按照下列步驟操作。

  1. 如果尚未中斷與 Cloud TPU 執行個體的連線,請中斷連線:

    (vm)$ exit

    系統現在顯示的提示訊息應為 username@projectname,代表您位於 Cloud Shell。

  2. 刪除 Cloud TPU。

    $ gcloud compute tpus tpu-vm delete $TPU_NAME \
        --project=$PROJECT_ID \
        --zone=$ZONE
  3. 執行下列指令,確認資源已刪除。確認 TPU 不再列出。刪除作業需要幾分鐘的時間才能完成。

    $ gcloud compute tpus tpu-vm list \
        --zone=$ZONE

後續步驟

進一步瞭解 Cloud TPU VM: