Ejecuta código de PyTorch en porciones de TPU

Antes de ejecutar los comandos de este documento, asegúrate de haber seguido las instrucciones que se indican en Configura una cuenta y un proyecto de Cloud TPU.

Una vez que tu código PyTorch se ejecute en una sola VM de TPU, puedes escalarlo verticalmente ejecutándolo en una porción de TPU. Las porciones de pod de TPU son varios paneles de TPU conectados entre sí en conexiones de red dedicadas de alta velocidad. En este documento, se presenta una introducción a la ejecución de código de PyTorch en porciones de TPU.

Crea una porción de Cloud TPU

  1. Define algunas variables de entorno para que los comandos sean más fáciles de usar.

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=europe-west4-b
    export ACCELERATOR_TYPE=v5p-32
    export RUNTIME_VERSION=v2-alpha-tpuv5

    Descripciones de las variables de entorno

    Variable Descripción
    PROJECT_ID Es el ID de tu proyecto de Google Cloud . Usa un proyecto existente o crea uno nuevo.
    TPU_NAME Es el nombre de la TPU.
    ZONE Es la zona en la que se creará la VM de TPU. Para obtener más información sobre las zonas compatibles, consulta Regiones y zonas de TPU.
    ACCELERATOR_TYPE El tipo de acelerador especifica la versión y el tamaño de Cloud TPU que deseas crear. Para obtener más información sobre los tipos de aceleradores compatibles con cada versión de TPU, consulta Versiones de TPU.
    RUNTIME_VERSION Es la versión de software de Cloud TPU.

  2. Ejecuta el siguiente comando para ejecutar tu VM de TPU:

    $ gcloud compute tpus tpu-vm create ${TPU_NAME} \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        --accelerator-type=${ACCELERATOR_TYPE} \
        --version=${RUNTIME_VERSION}

Instala PyTorch/XLA en tu porción

Después de crear la porción de TPU, debes instalar PyTorch en todos los hosts de la porción de TPU. Puedes hacerlo con el comando gcloud compute tpus tpu-vm ssh y los parámetros --worker=all y --commamnd.

Si los siguientes comandos fallan debido a un error de conexión SSH, es posible que las VMs de TPU no tengan direcciones IP externas. Para acceder a una VM de TPU sin una dirección IP externa, sigue las instrucciones que se indican en Conéctate a una VM de TPU sin una dirección IP pública.

  1. Instala PyTorch/XLA en todos los trabajadores de la VM de TPU:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        --worker=all \
        --command="pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html"
  2. Clona XLA en todos los trabajadores de la VM de TPU:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        --worker=all \
        --command="git clone https://github.com/pytorch/xla.git"

Ejecuta una secuencia de comandos de entrenamiento en tu porción de TPU

Ejecuta la secuencia de comandos de entrenamiento en todos los trabajadores. La secuencia de comandos de entrenamiento usa una estrategia de fragmentación de datos múltiples y programa único (SPMD). Para obtener más información sobre SPMD, consulta la Guía del usuario de SPMD para PyTorch/XLA.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --zone=${ZONE} \
   --project=${PROJECT_ID} \
   --worker=all \
   --command="PJRT_DEVICE=TPU python3 ~/xla/test/spmd/test_train_spmd_imagenet.py  \
   --fake_data \
   --model=resnet50  \
   --num_epochs=1 2>&1 | tee ~/logs.txt"

El entrenamiento tarda alrededor de 15 minutos. Cuando se complete, deberías ver un mensaje similar al siguiente:

Epoch 1 test end 23:49:15, Accuracy=100.00
     10.164.0.11 [0] Max Accuracy: 100.00%

Realiza una limpieza

Cuando termines de usar la VM de TPU, sigue estos pasos para limpiar los recursos.

  1. Desconéctate de la instancia de Cloud TPU, si aún no lo hiciste:

    (vm)$ exit

    El mensaje ahora debería mostrar username@projectname, que indica que estás en Cloud Shell.

  2. Borra tus recursos de Cloud TPU.

    $ gcloud compute tpus tpu-vm delete  \
        --zone=${ZONE}
  3. Ejecuta gcloud compute tpus tpu-vm list para verificar que los recursos se hayan borrado. Este proceso puede tardar varios minutos. El resultado del siguiente comando no debe incluir ninguno de los recursos creados en este instructivo:

    $ gcloud compute tpus tpu-vm list --zone=${ZONE}