PyTorch を使用した Cloud TPU での Resnet50 のトレーニング

このチュートリアルでは、PyTorch を使用して Cloud TPU デバイスで ResNet-50 モデルをトレーニングする方法を説明します。PyTorch と ImageNet データセットを使用する、TPU 用に最適化されたその他のイメージ分類モデルにも、同じパターンを適用できます。

このチュートリアルのモデルは、残余ネットワーク(ResNet)アーキテクチャを最初に導入する画像認識のためのディープ残余ラーニングに基づいています。このチュートリアルでは、50 層のバリアントの ResNet-50 を使用して、PyTorch/XLA を使ったモデルのトレーニング方法を説明します。

TPU VM を作成する

  1. Cloud Shell ウィンドウを開きます。

    Cloud Shell を開く

  2. TPU VM を作成する

    gcloud compute tpus tpu-vm create your-tpu-name \
    --accelerator-type=v3-8 \
    --version=tpu-ubuntu2204-base \
    --zone=us-central1-a \
    --project=your-project
  3. SSH を使用して TPU VM に接続します。

    gcloud compute tpus tpu-vm ssh  your-tpu-name --zone=us-central1-a
  4. TPU VM に PyTorch/XLA をインストールします。

    (vm)$ pip install torch torch_xla[tpu] torchvision -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
  5. PyTorch/XLA GitHub リポジトリのクローンを作成する

    (vm)$ git clone --depth=1 https://github.com/pytorch/xla.git
  6. 架空のデータでトレーニング スクリプトを実行する

    (vm) $ PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1