Melatih Resnet50 di Cloud TPU dengan PyTorch

Tutorial ini menunjukkan cara melatih model ResNet-50 di perangkat Cloud TPU dengan PyTorch. Anda dapat menerapkan pola yang sama ke model klasifikasi gambar yang dioptimalkan untuk TPU lainnya yang menggunakan PyTorch dan set data ImageNet.

Model dalam tutorial ini didasarkan pada Deep Residual Learning for Image Recognition, yang pertama kali memperkenalkan arsitektur jaringan residual (ResNet). Tutorial ini menggunakan varian 50 lapisan, ResNet-50, dan menunjukkan cara melatih model menggunakan PyTorch/XLA.

Membuat VM TPU

  1. Buka jendela Cloud Shell.

    Buka Cloud Shell

  2. Membuat VM TPU

    gcloud compute tpus tpu-vm create your-tpu-name \
    --accelerator-type=v3-8 \
    --version=tpu-ubuntu2204-base \
    --zone=us-central1-a \
    --project=your-project
  3. Hubungkan ke VM TPU Anda menggunakan SSH:

    gcloud compute tpus tpu-vm ssh  your-tpu-name --zone=us-central1-a
  4. Instal PyTorch/XLA di VM TPU Anda:

    (vm)$ pip install torch torch_xla[tpu] torchvision -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
  5. Clone repositori GitHub PyTorch/XLA

    (vm)$ git clone --depth=1 https://github.com/pytorch/xla.git
  6. Jalankan skrip pelatihan dengan data palsu

    (vm) $ PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1