Entrenar Resnet50 en TPU de Cloud con PyTorch

En este tutorial se explica cómo entrenar el modelo ResNet-50 en un dispositivo TPU de Cloud con PyTorch. Puedes aplicar el mismo patrón a otros modelos de clasificación de imágenes optimizados para TPUs que usen PyTorch y el conjunto de datos ImageNet.

El modelo de este tutorial se basa en Aprendizaje residual profundo para el reconocimiento de imágenes, que introduce por primera vez la arquitectura de la red residual (ResNet). En el tutorial se usa la variante de 50 capas, ResNet-50, y se muestra cómo entrenar el modelo con PyTorch/XLA.

Crear una VM de TPU

  1. Abre una ventana de Cloud Shell.

    Abrir Cloud Shell

  2. Crear una VM de TPU

    gcloud compute tpus tpu-vm create your-tpu-name \
    --accelerator-type=v3-8 \
    --version=tpu-ubuntu2204-base \
    --zone=us-central1-a \
    --project=your-project
    Google Cloud
  3. Conéctate a tu VM de TPU mediante SSH:

    gcloud compute tpus tpu-vm ssh  your-tpu-name --zone=us-central1-a
  4. Instala PyTorch/XLA en tu VM de TPU:

    (vm)$ pip install torch torch_xla[tpu] torchvision -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
  5. Clona el repositorio de GitHub PyTorch/XLA.

    (vm)$ git clone --depth=1 https://github.com/pytorch/xla.git
  6. Ejecuta la secuencia de comandos de entrenamiento con datos falsos

    (vm) $ PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1