Ejecuta el código JAX en porciones de TPU
Antes de ejecutar los comandos de este documento, asegúrate de haber seguido las instrucciones que se indican en Configura una cuenta y un proyecto de Cloud TPU.
Una vez que tu código JAX se ejecute en un único panel de TPU, puedes escalar verticalmente en una porción de pod de TPU. Las porciones de pod de TPU son varios paneles de TPU conectados entre sí en conexiones de red dedicadas de alta velocidad. Este documento es una introducción a la ejecución de código JAX en porciones de pod de TPU. Para obtener información más detallada, consulta Usa JAX en entornos de hosts y procesos múltiples.
Crea una porción de Cloud TPU
Crea algunas variables de entorno:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=europe-west4-b export ACCELERATOR_TYPE=v5litepod-32 export RUNTIME_VERSION=v2-alpha-tpuv5-lite
Descripciones de las variables de entorno
Variable Descripción PROJECT_IDEs el ID de tu proyecto de Google Cloud . Usa un proyecto existente o crea uno nuevo. TPU_NAMEEs el nombre de la TPU. ZONEEs la zona en la que se creará la VM de TPU. Para obtener más información sobre las zonas compatibles, consulta Regiones y zonas de TPU. ACCELERATOR_TYPEEl tipo de acelerador especifica la versión y el tamaño de Cloud TPU que deseas crear. Para obtener más información sobre los tipos de aceleradores compatibles con cada versión de TPU, consulta Versiones de TPU. RUNTIME_VERSIONEs la versión de software de Cloud TPU. Crea una porción de TPU con el comando
gcloud. Por ejemplo, para crear una porción de v5litepod-32, usa el siguiente comando:$ gcloud compute tpus tpu-vm create ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --accelerator-type=${ACCELERATOR_TYPE} \ --version=${RUNTIME_VERSION}
Instala JAX en tu porción
Después de crear la porción de TPU, debes instalar JAX en todos los hosts de la porción
de TPU. Puedes hacerlo con el comando gcloud compute tpus tpu-vm ssh y los
parámetros --worker=all y --commamnd.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --worker=all \ --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Ejecuta el código JAX en la porción
Para ejecutar el código JAX en una porción de pod de TPU, debes ejecutar el código en
cada host en la porción de pod de TPU. La llamada a jax.device_count() deja de responder hasta que se
la llama en cada host de la porción. En el siguiente ejemplo, se ilustra cómo ejecutar un
cálculo de JAX en una porción de TPU.
Prepara el código
Necesitas la versión gcloud >=344.0.0 (para el
comando scp).
Usa gcloud --version para verificar tu versión de gcloud y
ejecuta gcloud components upgrade, si es necesario.
Crea un archivo llamado example.py con el siguiente código:
import jax
# The total number of TPU cores in the slice
device_count = jax.device_count()
# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()
# The psum is performed over all mapped devices across the slice
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)
# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
print('global device count:', jax.device_count())
print('local device count:', jax.local_device_count())
print('pmap result:', r)
Copia example.py en todas las VMs de trabajador TPU de la porción
$ gcloud compute tpus tpu-vm scp ./example.py ${TPU_NAME}: \ --worker=all \ --zone=${ZONE} \ --project=${PROJECT_ID}
Si no usaste el comando scp anteriormente, es posible que veas un
error similar al siguiente:
ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH
agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try
again.
Para resolver el error, ejecuta el comando ssh-add como se muestra en el
mensaje de error y vuelve a ejecutarlo.
Ejecuta el código en la porción
Inicia el programa example.py en cada VM:
$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --worker=all \ --command="python3 ./example.py"
Resultado (producido con una porción de v5litepod-32):
global device count: 32
local device count: 4
pmap result: [32. 32. 32. 32.]
Realiza una limpieza
Cuando termines de usar la VM de TPU, sigue estos pasos para limpiar los recursos.
Borra tus recursos de Cloud TPU y Compute Engine.
$ gcloud compute tpus tpu-vm delete ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID}
Ejecuta
gcloud compute tpus execution-groups listpara verificar que los recursos se hayan borrado. Este proceso puede tardar varios minutos. El resultado del siguiente comando no debe incluir ninguno de los recursos creados en este instructivo:$ gcloud compute tpus tpu-vm list --zone=${ZONE} \ --project=${PROJECT_ID}