Entrena Resnet50 en Cloud TPU con PyTorch

En este instructivo, se muestra cómo entrenar el modelo ResNet-50 en un dispositivo Cloud TPU con PyTorch. Puedes aplicar el mismo patrón a otros modelos de clasificación de imágenes optimizados con TPU que usen PyTorch y el conjunto de datos ImageNet.

El modelo de este instructivo se basa en Aprendizaje residual profundo para el reconocimiento de imágenes, que presentó por primera vez la arquitectura de la red residual (ResNet). En el instructivo, se usa la variante de 50 capas, ResNet-50, y se muestra cómo entrenar el modelo con PyTorch/XLA.

Crea una VM de TPU

  1. Abre una ventana de Cloud Shell.

    Abra Cloud Shell

  2. Crea una VM de TPU

    gcloud compute tpus tpu-vm create your-tpu-name \
    --accelerator-type=v3-8 \
    --version=tpu-ubuntu2204-base \
    --zone=us-central1-a \
    --project=your-project
  3. Conéctate a tu VM de TPU con SSH:

    gcloud compute tpus tpu-vm ssh  your-tpu-name --zone=us-central1-a
  4. Instala PyTorch/XLA en tu VM de TPU:

    (vm)$ pip install torch torch_xla[tpu] torchvision -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
  5. Clona el repositorio de GitHub de PyTorch/XLA

    (vm)$ git clone --depth=1 https://github.com/pytorch/xla.git
  6. Ejecuta la secuencia de comandos de entrenamiento con datos falsos

    (vm) $ PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1