Ejecuta un cálculo en una VM de Cloud TPU con JAX

En este documento, se proporciona una breve introducción sobre cómo trabajar con JAX y Cloud TPU.

Antes de comenzar

Antes de ejecutar los comandos de este documento, debes crear una cuenta de Google Cloud, instalar Google Cloud CLI y configurar el comando gcloud. Para obtener más información, consulta Configura el entorno de Cloud TPU.

Crea una VM de Cloud TPU con gcloud

  1. Define algunas variables de entorno para que los comandos sean más fáciles de usar.

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-east5-a
    export ACCELERATOR_TYPE=v5litepod-8
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite

    Descripciones de las variables de entorno

    Variable Descripción
    PROJECT_ID Es el ID de tu proyecto de Google Cloud . Usa un proyecto existente o crea uno nuevo.
    TPU_NAME Es el nombre de la TPU.
    ZONE Es la zona en la que se creará la VM de TPU. Para obtener más información sobre las zonas compatibles, consulta Regiones y zonas de TPU.
    ACCELERATOR_TYPE El tipo de acelerador especifica la versión y el tamaño de Cloud TPU que deseas crear. Para obtener más información sobre los tipos de aceleradores compatibles con cada versión de TPU, consulta Versiones de TPU.
    RUNTIME_VERSION Es la versión de software de Cloud TPU.

  2. Ejecuta el siguiente comando desde Cloud Shell o la terminal de tu computadora en la que esté instalada Google Cloud CLI para crear tu VM de TPU.

    $ gcloud compute tpus tpu-vm create $TPU_NAME \
        --project=$PROJECT_ID \
        --zone=$ZONE \
        --accelerator-type=$ACCELERATOR_TYPE \
        --version=$RUNTIME_VERSION

Conéctate a la VM de tu Cloud TPU

Conéctate a tu VM de TPU con SSH con el siguiente comando:

$ gcloud compute tpus tpu-vm ssh $TPU_NAME \
    --project=$PROJECT_ID \
    --zone=$ZONE

Si no puedes conectarte a una VM de TPU con SSH, es posible que la VM de TPU no tenga una dirección IP externa. Para acceder a una VM de TPU sin una dirección IP externa, sigue las instrucciones que se indican en Conéctate a una VM de TPU sin una dirección IP pública.

Instala JAX en tu VM de Cloud TPU

(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Verificación del sistema

Verifica que JAX pueda acceder a la TPU y ejecutar operaciones básicas:

  1. Inicia el intérprete de Python 3:

    (vm)$ python3
    >>> import jax
  2. Muestra la cantidad de núcleos de TPU disponibles:

    >>> jax.device_count()

Se muestra la cantidad de núcleos de TPU. La cantidad de núcleos que se muestra depende de la versión de TPU que usas. Para obtener más información, consulta Versiones de TPU.

Cómo hacer un cálculo

>>> jax.numpy.add(1, 1)

Se muestra el resultado de la suma de NumPy:

Resultado del comando:

Array(2, dtype=int32, weak_type=True)

Sal del intérprete de Python

>>> exit()

Ejecuta el código JAX en una VM de TPU

Ahora puedes ejecutar cualquier código de JAX que desees. Los ejemplos de Flax son un excelente punto de partida para ejecutar modelos del AA estándar en JAX. Por ejemplo, para entrenar una red convolucional básica de MNIST, sigue los siguientes pasos:

  1. Instala las dependencias de los ejemplos de Flax:

    (vm)$ pip install --upgrade clu
    (vm)$ pip install tensorflow
    (vm)$ pip install tensorflow_datasets
  2. Instala Flax:

    (vm)$ git clone https://github.com/google/flax.git
    (vm)$ pip install --user flax
  3. Ejecuta la secuencia de comandos de entrenamiento de MNIST de Flax:

    (vm)$ cd flax/examples/mnist
    (vm)$ python3 main.py --workdir=/tmp/mnist \
        --config=configs/default.py \
        --config.learning_rate=0.05 \
        --config.num_epochs=5

La secuencia de comandos descarga el conjunto de datos y empieza el entrenamiento. El resultado de la secuencia de comandos debería verse de la siguiente manera:

I0214 18:00:50.660087 140369022753856 train.py:146] epoch:  1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88
I0214 18:00:52.015867 140369022753856 train.py:146] epoch:  2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72
I0214 18:00:53.377511 140369022753856 train.py:146] epoch:  3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04
I0214 18:00:54.727168 140369022753856 train.py:146] epoch:  4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15
I0214 18:00:56.082807 140369022753856 train.py:146] epoch:  5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18

Realiza una limpieza

Sigue estos pasos para evitar que se apliquen cargos a tu cuenta de Google Cloud por los recursos que usaste en esta página.

Cuando termines de usar la VM de TPU, sigue estos pasos para limpiar los recursos.

  1. Desconéctate de la instancia de Cloud TPU, si aún no lo hiciste:

    (vm)$ exit

    Tu prompt ahora debería ser username@projectname, lo que indica que estás en Cloud Shell.

  2. Borra tu Cloud TPU:

    $ gcloud compute tpus tpu-vm delete $TPU_NAME \
        --project=$PROJECT_ID \
        --zone=$ZONE
  3. Ejecuta el siguiente comando para verificar que los recursos se hayan borrado. Asegúrate de que tu TPU ya no aparezca en la lista. Este proceso puede tardar varios minutos.

    $ gcloud compute tpus tpu-vm list \
        --zone=$ZONE

Notas de rendimiento

Estos son algunos detalles importantes que son particularmente relevantes para usar TPU en JAX.

Relleno

Una de las causas más comunes del rendimiento lento en las TPU es la introducción de padding inadvertido:

  • Los arreglos en Cloud TPU están en mosaicos. Esto implica el relleno de una de las dimensiones a un múltiplo de 8 y de otra a un múltiplo de 128.
  • La unidad de multiplicación de matrices funciona mejor con pares de matrices grandes que minimizan la necesidad de relleno.

bfloat16 dtype

De forma predeterminada, la multiplicación de matrices en JAX en las TPU usa bfloat16 con acumulación de float32. Esto se puede controlar con el argumento de precisión en las llamadas a funciones jax.numpy relevantes (matmul, dot, einsum, etc.). En particular:

  • precision=jax.lax.Precision.DEFAULT: Usa precisión mixta de bfloat16 (la más rápida)
  • precision=jax.lax.Precision.HIGH: Usa varios pases de MXU para lograr una mayor precisión
  • precision=jax.lax.Precision.HIGHEST: Usa aún más pases de MXU para lograr una precisión completa de float32

JAX también agrega el tipo de datos bfloat16, que puedes usar para convertir de forma explícita los arrays a bfloat16. Por ejemplo, jax.numpy.array(x, dtype=jax.numpy.bfloat16).

¿Qué sigue?

Para obtener más información sobre Cloud TPU, consulta los siguientes vínculos: