Entraîner Resnet50 sur Cloud TPU avec PyTorch

Ce tutoriel explique comment entraîner le modèle ResNet-50 sur un appareil Cloud TPU avec PyTorch. La même procédure peut s'appliquer à d'autres modèles de classification d'image optimisés pour TPU, qui utilisent PyTorch et l'ensemble de données ImageNet.

Le modèle utilisé dans ce tutoriel est basé sur l'article Deep Residual Learning for Image Recognition (Deep learning résiduel pour la reconnaissance d'images), qui présente l'architecture de réseau résiduel (ResNet). Le tutoriel emploie la variante à 50 couches, ResNet-50, et illustre l'entraînement du modèle à l'aide de PyTorch/XLA.

Créez une VM TPU.

  1. Ouvrez une fenêtre Cloud Shell.

    Ouvrir Cloud Shell

  2. Créez une VM TPU.

    gcloud compute tpus tpu-vm create your-tpu-name \
    --accelerator-type=v3-8 \
    --version=tpu-ubuntu2204-base \
    --zone=us-central1-a \
    --project=your-project
  3. Connectez-vous à votre VM TPU à l'aide de SSH :

    gcloud compute tpus tpu-vm ssh  your-tpu-name --zone=us-central1-a
  4. Installez PyTorch/XLA sur votre VM TPU :

    (vm)$ pip install torch torch_xla[tpu] torchvision -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
  5. Clonez le dépôt GitHub PyTorch/XLA.

    (vm)$ git clone --depth=1 https://github.com/pytorch/xla.git
  6. Exécutez le script d'entraînement avec des données factices.

    (vm) $ PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1