Entrena un modelo con TPU v5e

Con una huella menor de 256 chips por Pod, la TPU v5e se optimizó con el objetivo de ser un producto de alto valor para el entrenamiento, el ajuste y la entrega de Transformers, la función de texto a imagen y redes neuronales convolucionales (CNN). Para obtener más información sobre el uso de Cloud TPU v5e para la entrega, consulta Inferencia con v5e.

Para obtener más información sobre el hardware y los parámetros de configuración de TPU v5e de Cloud TPU, consulta TPU v5e.

Comienza

En las siguientes secciones, se describe cómo comenzar a usar TPU v5e.

Cuota de solicitudes

Necesitas una cuota para usar TPU v5e para el entrenamiento. Existen diferentes tipos de cuotas para las TPU según demanda y las reservadas, y las VMs Spot de TPU. Si usas tu TPU v5e para la inferencia, se requieren cuotas independientes. Para obtener más información sobre las cuotas, consulta Cuotas. Para solicitar cuota de TPU v5e, comunícate con Ventas de Cloud.

Crea una cuenta y un proyecto de Google Cloud

Necesitas una cuenta y un proyecto de Google Cloud para usar Cloud TPU. Para obtener más información, consulta Configura un entorno de Cloud TPU.

Crea una Cloud TPU

La práctica recomendada es aprovisionar Cloud TPU v5e como recursos en cola con el comando queued-resource create. Para obtener más información, consulta Administra recursos en cola.

También puedes usar la API de creación de nodos (gcloud compute tpus tpu-vm create) para aprovisionar Cloud TPU v5e. Para obtener más información, consulta Administra recursos TPU.

Para obtener más información sobre los parámetros de configuración de v5e disponibles para el entrenamiento, consulta Tipos de Cloud TPU v5e para el entrenamiento.

Configura el framework

En esta sección, se describe el proceso de configuración general para el entrenamiento de modelos personalizados con JAX o PyTorch, y TPU v5e.

Si deseas obtener instrucciones para configurar la inferencia, consulta la introducción a la inferencia de v5e.

Define algunas variables de entorno:

export PROJECT_ID=your_project_ID
export ACCELERATOR_TYPE=v5litepod-16
export ZONE=us-west4-a
export TPU_NAME=your_tpu_name
export QUEUED_RESOURCE_ID=your_queued_resource_id

Configura JAX

Si tienes formas de porciones con más de 8 chips, tendrás varias VMs en una porción. En este caso, debes usar la marca --worker=all para ejecutar la instalación en todas las VMs de TPU en un solo paso sin usar SSH para acceder a cada una por separado:

gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

Descripciones de las marcas de comandos

Variable Descripción
TPU_NAME Es el ID de texto asignado por el usuario para la TPU que se crea cuando se asigna la solicitud de un recurso en cola.
PROJECT_ID En el nombre del proyecto deGoogle Cloud . Usa un proyecto existente o crea uno nuevo en Configura tu proyecto de Google Cloud .
ZONE Consulta el documento Regiones y zonas de TPU para conocer las zonas compatibles.
worker Es la VM de TPU que tiene acceso a las TPU subyacentes.

Puedes ejecutar el siguiente comando para verificar la cantidad de dispositivos (los resultados que se muestran aquí se produjeron con una porción v5litepod-16). Este código verifica que todo esté instalado de forma correcta, ya que comprueba que JAX vea los TensorCores de Cloud TPU y pueda ejecutar operaciones básicas:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='python3 -c "import jax; print(jax.device_count()); print(jax.local_device_count())"'

El resultado será similar al siguiente ejemplo:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16
4
16
4
16
4
16
4

jax.device_count() muestra la cantidad total de chips en una porción determinada. jax.local_device_count() indica la cantidad de chips a los que puede acceder una sola VM en esta porción.

# Check the number of chips in the given slice by summing the count of chips
# from all VMs through the
# jax.local_device_count() API call.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='python3 -c "import jax; xs=jax.numpy.ones(jax.local_device_count()); print(jax.pmap(lambda x: jax.lax.psum(x, \"i\"), axis_name=\"i\")(xs))"'

El resultado será similar al siguiente ejemplo:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]

Prueba los instructivos de JAX en este documento para comenzar a entrenar el modelo v5e con JAX.

Configura PyTorch

Ten en cuenta que v5e solo admite el entorno de ejecución de PJRT, y PyTorch 2.1 en adelante usarán PJRT como el entorno de ejecución predeterminado para todas las versiones de TPU.

En esta sección, se describe cómo comenzar a usar PJRT en v5e con PyTorch/XLA con comandos para todos los trabajadores.

Instala dependencias

gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
      sudo apt-get update -y
      sudo apt-get install libomp5 -y
      pip install mkl mkl-include
      pip install tf-nightly tb-nightly tbp-nightly
      pip install numpy
      sudo apt-get install libopenblas-dev -y
      pip install torch~=PYTORCH_VERSION torchvision torch_xla[tpu]~=PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'

Reemplaza PYTORCH_VERSION por la versión de PyTorch que deseas usar. PYTORCH_VERSION se usa para especificar la misma versión para PyTorch/XLA. Se recomienda la 2.6.0.

Para obtener más información sobre las versiones de PyTorch y PyTorch/XLA, consulta Primeros pasos en PyTorch y Versiones de PyTorch/XLA.

Para obtener más información sobre cómo instalar PyTorch/XLA, consulta Instalación de PyTorch/XLA.

Si recibes un error cuando instalas las ruedas para torch, torch_xla o torchvision, como pkg_resources.extern.packaging.requirements.InvalidRequirement: Expected end or semicolon (after name and no valid version specifier) torch==nightly+20230222, cámbiate a una versión inferior con este comando:

pip3 install setuptools==62.1.0

Ejecuta una secuencia de comandos con PJRT

unset LD_PRELOAD

A continuación, se muestra un ejemplo en el que se usa una secuencia de comandos de Python para realizar un cálculo en una VM de v5e:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
      export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.local/lib/
      export PJRT_DEVICE=TPU
      export PT_XLA_DEBUG=0
      export USE_TORCH=ON
      unset LD_PRELOAD
      export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
      python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"'

Esto genera un resultado similar al que se muestra a continuación:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
xla:0
tensor([[ 1.8611, -0.3114, -2.4208],
[-1.0731, 0.3422, 3.1445],
[ 0.5743, 0.2379, 1.1105]], device='xla:0')
xla:0
tensor([[ 1.8611, -0.3114, -2.4208],
[-1.0731, 0.3422, 3.1445],
[ 0.5743, 0.2379, 1.1105]], device='xla:0')

Prueba los instructivos de PyTorch en este documento para comenzar a entrenar v5e con PyTorch.

Borra tu TPU y el recurso en cola al final de la sesión. Para borrar un recurso en cola, borra la porción y, luego, el recurso en cola en 2 pasos:

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

Estos dos pasos también se pueden usar para quitar las solicitudes de recursos en cola que se encuentran en el estado FAILED.

Ejemplos de JAX/FLAX

En las siguientes secciones, se describen ejemplos de cómo entrenar modelos de JAX y FLAX en TPU v5e.

Entrena ImageNet en v5e

En este instructivo, se describe cómo entrenar ImageNet en v5e con datos de entrada simulados. Si deseas usar datos reales, consulta el archivo readme en GitHub.

Configura

  1. Crea variables de entorno:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west4-a
    export ACCELERATOR_TYPE=v5litepod-8
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    Descripciones de las variables de entorno

    Variable Descripción
    PROJECT_ID Es el ID de tu proyecto de Google Cloud . Usa un proyecto existente o crea uno nuevo.
    TPU_NAME Es el nombre de la TPU.
    ZONE Es la zona en la que se creará la VM de TPU. Para obtener más información sobre las zonas compatibles, consulta Regiones y zonas de TPU.
    ACCELERATOR_TYPE El tipo de acelerador especifica la versión y el tamaño de Cloud TPU que deseas crear. Para obtener más información sobre los tipos de aceleradores compatibles con cada versión de TPU, consulta Versiones de TPU.
    RUNTIME_VERSION Es la versión de software de Cloud TPU.
    SERVICE_ACCOUNT Es la dirección de correo electrónico de tu cuenta de servicio. Para encontrarla, dirígete a la página Cuentas de servicio en la consola de Google Cloud .

    Un ejemplo es tpu-service-account@PROJECT_ID.iam.gserviceaccount.com

    QUEUED_RESOURCE_ID Es el ID de texto asignado por el usuario para la solicitud del recurso en cola.

  2. Crea un recurso TPU:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --service-account=${SERVICE_ACCOUNT}
    

    Podrás establecer una conexión SSH a tu VM de TPU una vez que el recurso en cola esté en el estado ACTIVE.

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    Cuando el recurso en cola esté en el estado ACTIVE, el resultado será similar al siguiente:

     state: ACTIVE
    
  3. Instala la versión más reciente de JAX y jaxlib:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  4. Clona el modelo ImageNet y, luego, instala los requisitos correspondientes:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command="git clone https://github.com/coolkp/flax.git && cd flax && git checkout pmap-orbax-conversion && git pull"
    
  5. Para generar datos simulados, el modelo necesita información sobre las dimensiones del conjunto de datos. Esto se puede recopilar de los metadatos del conjunto de datos de ImageNet:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command="cd flax/examples/imagenet && pip install -r requirements-cloud-tpu.txt"
    

Entrena el modelo

Una vez que hayas completado todos los pasos anteriores, podrás entrenar el modelo.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command="cd flax/examples/imagenet && bash ../../tests/download_dataset_metadata.sh && JAX_PLATFORMS=tpu python imagenet_fake_data_benchmark.py"

Borra la TPU y el recurso en cola

Borra tu TPU y el recurso en cola al final de la sesión.

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

Modelos de FLAX de Hugging Face

Los modelos de Hugging Face implementados en FLAX funcionan de inmediato en Cloud TPU v5e. En esta sección, se brindan instrucciones para ejecutar modelos populares.

Entrena ViT en Imagenette

En este instructivo, se muestra cómo entrenar el modelo Vision Transformer (ViT) de Hugging Face con el conjunto de datos Imagenette de Fast AI en Cloud TPU v5e.

El modelo ViT fue el primero en entrenar con éxito un codificador Transformer en ImageNet con resultados excelentes en comparación con las redes convolucionales. Para obtener más información, consulta Descripción general de ViT.

Configura

  1. Crea variables de entorno:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west4-a
    export ACCELERATOR_TYPE=v5litepod-16
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    Descripciones de las variables de entorno

    Variable Descripción
    PROJECT_ID Es el ID de tu proyecto de Google Cloud . Usa un proyecto existente o crea uno nuevo.
    TPU_NAME Es el nombre de la TPU.
    ZONE Es la zona en la que se creará la VM de TPU. Para obtener más información sobre las zonas compatibles, consulta Regiones y zonas de TPU.
    ACCELERATOR_TYPE El tipo de acelerador especifica la versión y el tamaño de Cloud TPU que deseas crear. Para obtener más información sobre los tipos de aceleradores compatibles con cada versión de TPU, consulta Versiones de TPU.
    RUNTIME_VERSION Es la versión de software de Cloud TPU.
    SERVICE_ACCOUNT Es la dirección de correo electrónico de tu cuenta de servicio. Para encontrarla, dirígete a la página Cuentas de servicio en la consola de Google Cloud .

    Un ejemplo es tpu-service-account@PROJECT_ID.iam.gserviceaccount.com

    QUEUED_RESOURCE_ID Es el ID de texto asignado por el usuario para la solicitud del recurso en cola.

  2. Crea un recurso TPU:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --service-account=${SERVICE_ACCOUNT}
    

    Podrás establecer una conexión SSH a tu VM de TPU una vez que el recurso en cola esté en estado ACTIVE:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    Cuando el recurso en cola esté en el estado ACTIVE, el resultado será similar al siguiente:

     state: ACTIVE
    
  3. Instala JAX y su biblioteca:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  4. Descarga el repositorio de Hugging Face y, luego, instala los requisitos:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='git clone https://github.com/huggingface/transformers.git && cd transformers && pip install . && pip install -r examples/flax/_tests_requirements.txt && pip install --upgrade huggingface-hub urllib3 zipp && pip install tensorflow==2.19 && sed -i 's/torchvision==0.12.0+cpu/torchvision==0.22.1/' examples/flax/vision/requirements.txt && pip install -r examples/flax/vision/requirements.txt && pip install tf-keras'
    
  5. Descarga el conjunto de datos de Imagenette:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='cd transformers && wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz && tar -xvzf imagenette2.tgz'
    

Entrena el modelo

Entrena el modelo con un búfer previamente asignado de 4 GB.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='cd transformers && JAX_PLATFORMS=tpu python3 examples/flax/vision/run_image_classification.py --train_dir "imagenette2/train" --validation_dir "imagenette2/val" --output_dir "./vit-imagenette" --learning_rate 1e-3 --preprocessing_num_workers 32 --per_device_train_batch_size 8 --per_device_eval_batch_size 8 --model_name_or_path google/vit-base-patch16-224-in21k --num_train_epochs 3'

Borra la TPU y el recurso en cola

Borra tu TPU y el recurso en cola al final de la sesión.

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

Resultados de comparativas de ViT

La secuencia de comandos de entrenamiento se ejecutó en v5litepod-4, v5litepod-16 y v5litepod-64. En la siguiente tabla, se muestran las capacidades de procesamiento con diferentes tipos de aceleradores.

Tipo de acelerador v5litepod-4 v5litepod-16 v5litepod-64
Ciclo de entrenamiento 3 3 3
Tamaño del lote global 32 128 512
Capacidad de procesamiento (ejemplos/s) 263.40 429.34 470.71

Entrena la difusión en Pokémon

En este instructivo, se muestra cómo entrenar el modelo Stable Diffusion de Hugging Face con el conjunto de datos de Pokémon en Cloud TPU v5e.

El modelo Stable Diffusion es un modelo latente de texto a imagen que genera imágenes fotorrealistas a partir de cualquier entrada de texto. Para obtener más información, consulta los siguientes recursos:

Configura

  1. Establece una variable de entorno para el nombre de tu bucket de almacenamiento:

    export GCS_BUCKET_NAME=your_bucket_name
  2. Configura un bucket de almacenamiento para el resultado de tu modelo:

    gcloud storage buckets create gs://GCS_BUCKET_NAME \
        --project=your_project \
        --location=us-west1
  3. Crea variables de entorno:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west1-c
    export ACCELERATOR_TYPE=v5litepod-16
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    Descripciones de las variables de entorno

    Variable Descripción
    PROJECT_ID Es el ID de tu proyecto de Google Cloud . Usa un proyecto existente o crea uno nuevo.
    TPU_NAME Es el nombre de la TPU.
    ZONE Es la zona en la que se creará la VM de TPU. Para obtener más información sobre las zonas compatibles, consulta Regiones y zonas de TPU.
    ACCELERATOR_TYPE El tipo de acelerador especifica la versión y el tamaño de Cloud TPU que deseas crear. Para obtener más información sobre los tipos de aceleradores compatibles con cada versión de TPU, consulta Versiones de TPU.
    RUNTIME_VERSION Es la versión de software de Cloud TPU.
    SERVICE_ACCOUNT Es la dirección de correo electrónico de tu cuenta de servicio. Para encontrarla, dirígete a la página Cuentas de servicio en la consola de Google Cloud .

    Un ejemplo es tpu-service-account@PROJECT_ID.iam.gserviceaccount.com

    QUEUED_RESOURCE_ID Es el ID de texto asignado por el usuario para la solicitud del recurso en cola.

  4. Crea un recurso TPU:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --service-account=${SERVICE_ACCOUNT}
    

    Podrás establecer una conexión SSH a tu VM de TPU una vez que el recurso en cola esté en el estado ACTIVE.

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    Cuando el recurso en cola esté en el estado ACTIVE, el resultado será similar al siguiente:

     state: ACTIVE
    
  5. Instala JAX y su biblioteca.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='pip install "jax[tpu]==0.4.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  6. Descarga el repositorio de Hugging Face y, luego, instala los requisitos.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
         --project=${PROJECT_ID} \
         --zone=${ZONE} \
         --worker=all \
         --command='git clone https://github.com/RissyRan/diffusers.git && cd diffusers && pip install . && pip install -U -r examples/text_to_image/requirements_flax.txt && pip install tensorflow==2.17.1 clu && pip install tensorboard==2.17.1'
    

Entrena el modelo

Entrena el modelo con un búfer previamente asignado de 4 GB.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} --zone=${ZONE} --project=${PROJECT_ID} --worker=all --command="
    git clone https://github.com/google/maxdiffusion
    cd maxdiffusion
    pip3 install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    pip3 install -r requirements.txt
    pip3 install .
    pip3 install gcsfs
    export LIBTPU_INIT_ARGS=''
    python -m src.maxdiffusion.train src/maxdiffusion/configs/base_2_base.yml run_name=my_run \
    jax_cache_dir=gs://${GCS_BUCKET_NAME} activations_dtype=bfloat16 weights_dtype=bfloat16 \
    per_device_batch_size=1 precision=DEFAULT dataset_save_location=gs://${GCS_BUCKET_NAME} \
    output_dir=gs://${GCS_BUCKET_NAME}/ attention=flash"

Realiza una limpieza

Borra tu TPU, el recurso en cola y el bucket de Cloud Storage al final de la sesión.

  1. Borra tu TPU:

    gcloud compute tpus tpu-vm delete ${TPU_NAME} \
        --project=${PROJECT_ID} \
        --zone=${ZONE} \
        --quiet
    
  2. Borra el recurso en cola:

    gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
        --project=${PROJECT_ID} \
        --zone=${ZONE} \
        --quiet
    
  3. Borra el bucket de Cloud Storage:

    gcloud storage rm -r gs://${GCS_BUCKET_NAME}
    

Resultados de comparativas para la difusión

La secuencia de comandos de entrenamiento se ejecutó en v5litepod-4, v5litepod-16 y v5litepod-64. En la siguiente tabla, se muestran las capacidades de procesamiento.

Tipo de acelerador v5litepod-4 v5litepod-16 v5litepod-64
Paso del entrenamiento 1500 1500 1500
Tamaño del lote global 32 64 128
Capacidad de procesamiento (ejemplos/s) 36.53 43.71 49.36

PyTorch/XLA

En las siguientes secciones, se describen ejemplos de cómo entrenar modelos de PyTorch/XLA en TPU v5e.

Entrena ResNet con el entorno de ejecución de PJRT

PyTorch/XLA está migrando de XRT a PjRt a partir de PyTorch 2.0 en adelante. A continuación, se incluyen las instrucciones actualizadas sobre cómo configurar v5e para las cargas de trabajo de entrenamiento de PyTorch/XLA.

Configura
  1. Crea variables de entorno:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west4-a
    export ACCELERATOR_TYPE=v5litepod-16
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    Descripciones de las variables de entorno

    Variable Descripción
    PROJECT_ID Es el ID de tu proyecto de Google Cloud . Usa un proyecto existente o crea uno nuevo.
    TPU_NAME Es el nombre de la TPU.
    ZONE Es la zona en la que se creará la VM de TPU. Para obtener más información sobre las zonas compatibles, consulta Regiones y zonas de TPU.
    ACCELERATOR_TYPE El tipo de acelerador especifica la versión y el tamaño de Cloud TPU que deseas crear. Para obtener más información sobre los tipos de aceleradores compatibles con cada versión de TPU, consulta Versiones de TPU.
    RUNTIME_VERSION Es la versión de software de Cloud TPU.
    SERVICE_ACCOUNT Es la dirección de correo electrónico de tu cuenta de servicio. Para encontrarla, dirígete a la página Cuentas de servicio en la consola de Google Cloud .

    Un ejemplo es tpu-service-account@PROJECT_ID.iam.gserviceaccount.com

    QUEUED_RESOURCE_ID Es el ID de texto asignado por el usuario para la solicitud del recurso en cola.

  2. Crea un recurso TPU:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --service-account=${SERVICE_ACCOUNT}
    

    Podrás establecer una conexión SSH a tu VM de TPU una vez que tu recurso en cola esté en estado ACTIVE.

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    Cuando el recurso en cola esté en el estado ACTIVE, el resultado será similar al siguiente:

     state: ACTIVE
    
  3. Instala las dependencias específicas de Torch/XLA.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --project=${PROJECT_ID} \
      --zone=${ZONE} \
      --worker=all \
      --command='
         sudo apt-get update -y
         sudo apt-get install libomp5 -y
         pip3 install mkl mkl-include
         pip3 install tf-nightly tb-nightly tbp-nightly
         pip3 install numpy
         sudo apt-get install libopenblas-dev -y
         pip install torch==PYTORCH_VERSION torchvision torch_xla[tpu]==PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'

    Reemplaza PYTORCH_VERSION por la versión de PyTorch que deseas usar. PYTORCH_VERSION se usa para especificar la misma versión para PyTorch/XLA. Se recomienda la 2.6.0.

    Para obtener más información sobre las versiones de PyTorch y PyTorch/XLA, consulta Primeros pasos en PyTorch y Versiones de PyTorch/XLA.

    Para obtener más información sobre cómo instalar PyTorch/XLA, consulta Instalación de PyTorch/XLA.

Entrena el modelo ResNet
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
      date
      export PJRT_DEVICE=TPU
      export PT_XLA_DEBUG=0
      export USE_TORCH=ON
      export XLA_USE_BF16=1
      export LIBTPU_INIT_ARGS=--xla_jf_auto_cross_replica_sharding
      export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
      export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
      git clone https://github.com/pytorch/xla.git
      cd xla/
      git checkout release-r2.6
      python3 test/test_train_mp_imagenet.py --model=resnet50  --fake_data --num_epochs=1 --num_workers=16  --log_steps=300 --batch_size=64 --profile'

Borra la TPU y el recurso en cola

Borra tu TPU y el recurso en cola al final de la sesión.

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet
Resultado de la comparativa

En la siguiente tabla, se muestran las capacidades de procesamiento comparativas.

Tipo de acelerador Capacidad de procesamiento (ejemplos/segundo)
v5litepod-4 4,240 ej./s
v5litepod-16 10,810 ej./s
v5litepod-64 46,154 ej./s

Entrena ViT en v5e

En este instructivo, se explica cómo ejecutar ViT en v5e con el repositorio de Hugging Face en PyTorch/XLA en el conjunto de datos cifar10.

Configura

  1. Crea variables de entorno:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west4-a
    export ACCELERATOR_TYPE=v5litepod-16
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    Descripciones de las variables de entorno

    Variable Descripción
    PROJECT_ID Es el ID de tu proyecto de Google Cloud . Usa un proyecto existente o crea uno nuevo.
    TPU_NAME Es el nombre de la TPU.
    ZONE Es la zona en la que se creará la VM de TPU. Para obtener más información sobre las zonas compatibles, consulta Regiones y zonas de TPU.
    ACCELERATOR_TYPE El tipo de acelerador especifica la versión y el tamaño de Cloud TPU que deseas crear. Para obtener más información sobre los tipos de aceleradores compatibles con cada versión de TPU, consulta Versiones de TPU.
    RUNTIME_VERSION Es la versión de software de Cloud TPU.
    SERVICE_ACCOUNT Es la dirección de correo electrónico de tu cuenta de servicio. Para encontrarla, dirígete a la página Cuentas de servicio en la consola de Google Cloud .

    Un ejemplo es tpu-service-account@PROJECT_ID.iam.gserviceaccount.com

    QUEUED_RESOURCE_ID Es el ID de texto asignado por el usuario para la solicitud del recurso en cola.

  2. Crea un recurso TPU:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --service-account=${SERVICE_ACCOUNT}
    

    Podrás establecer una conexión SSH a tu VM de TPU una vez que tu recurso en cola esté en el estado ACTIVE:

     gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    Cuando el recurso en cola esté en el estado ACTIVE, el resultado será similar al siguiente:

     state: ACTIVE
    
  3. Instala las dependencias de PyTorch/XLA.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='
      sudo apt-get update -y
      sudo apt-get install libomp5 -y
      pip3 install mkl mkl-include
      pip3 install tf-nightly tb-nightly tbp-nightly
      pip3 install numpy
      sudo apt-get install libopenblas-dev -y
      pip install torch==PYTORCH_VERSION torchvision torch_xla[tpu]==PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
      pip install jax==0.4.38 jaxlib==0.4.38 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/

    Reemplaza PYTORCH_VERSION por la versión de PyTorch que deseas usar. PYTORCH_VERSION se usa para especificar la misma versión para PyTorch/XLA. Se recomienda la 2.6.0.

    Para obtener más información sobre las versiones de PyTorch y PyTorch/XLA, consulta Primeros pasos en PyTorch y Versiones de PyTorch/XLA.

    Para obtener más información sobre cómo instalar PyTorch/XLA, consulta Instalación de PyTorch/XLA.

  4. Descarga el repositorio de Hugging Face y, luego, instala los requisitos.

       gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command="
          git clone https://github.com/suexu1025/transformers.git vittransformers; \
          cd vittransformers; \
          pip3 install .; \
          pip3 install datasets; \
          wget https://github.com/pytorch/xla/blob/master/scripts/capture_profile.py"
    

Entrena el modelo

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
      export PJRT_DEVICE=TPU
      export PT_XLA_DEBUG=0
      export USE_TORCH=ON
      export TF_CPP_MIN_LOG_LEVEL=0
      export XLA_USE_BF16=1
      export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
      export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
      cd vittransformers
      python3 -u examples/pytorch/xla_spawn.py --num_cores 4 examples/pytorch/image-pretraining/run_mae.py --dataset_name=cifar10 \
      --remove_unused_columns=False \
      --label_names=pixel_values \
      --mask_ratio=0.75 \
      --norm_pix_loss=True \
      --do_train=true \
      --do_eval=true \
      --base_learning_rate=1.5e-4 \
      --lr_scheduler_type=cosine \
      --weight_decay=0.05 \
      --num_train_epochs=3 \
      --warmup_ratio=0.05 \
      --per_device_train_batch_size=8 \
      --per_device_eval_batch_size=8 \
      --logging_strategy=steps \
      --logging_steps=30 \
      --evaluation_strategy=epoch \
      --save_strategy=epoch \
      --load_best_model_at_end=True \
      --save_total_limit=3 \
      --seed=1337 \
      --output_dir=MAE \
      --overwrite_output_dir=true \
      --logging_dir=./tensorboard-metrics \
      --tpu_metrics_debug=true'

Borra la TPU y el recurso en cola

Borra tu TPU y el recurso en cola al final de la sesión.

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

Resultado de la comparativa

En la siguiente tabla, se muestran las comparativas de las capacidades de procesamiento para los diferentes tipos de aceleradores.

v5litepod-4 v5litepod-16 v5litepod-64
Ciclo de entrenamiento 3 3 3
Tamaño del lote global 32 128 512
Capacidad de procesamiento (ejemplos/s) 201 657 2,844