Como treinar o Resnet50 no Cloud TPU com PyTorch

Neste tutorial, mostramos como treinar o modelo ResNet-50 em um dispositivo Cloud TPU com PyTorch. É possível aplicar o mesmo padrão a outros modelos de classificação de imagem otimizados para TPU que usam o PyTorch e o conjunto de dados do ImageNet.

O modelo deste tutorial é baseado no artigo Deep Residual Learning for Image Recognition, que foi o primeiro a apresentar a arquitetura de rede residual ou ResNet. O tutorial usa a variante de 50 camadas, ResNet-50, e demonstra o treinamento do modelo usando PyTorch/XLA.

Criar uma VM de TPU

  1. Abra uma janela do Cloud Shell.

    Abra o Cloud Shell

  2. Criar uma VM de TPU

    gcloud compute tpus tpu-vm create your-tpu-name \
    --accelerator-type=v3-8 \
    --version=tpu-ubuntu2204-base \
    --zone=us-central1-a \
    --project=your-project
  3. Conecte-se à VM da TPU usando SSH:

    gcloud compute tpus tpu-vm ssh  your-tpu-name --zone=us-central1-a
  4. Instale o PyTorch/XLA na VM de TPU:

    (vm)$ pip install torch torch_xla[tpu] torchvision -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
  5. Clone o repositório do PyTorch/XLA no GitHub.

    (vm)$ git clone --depth=1 https://github.com/pytorch/xla.git
  6. Executar o script de treinamento com dados falsos

    (vm) $ PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1