Cómo portar cargas de trabajo de JAX a Pathways

Debido a la naturaleza distribuida de JAX con Pathways, es posible que algunas operaciones no se escalen bien debido a la sobrecarga de comunicación. Si bien Pathways minimiza estos gastos generales con funciones como el envío asíncrono, hay algunas cosas que debes tener en cuenta cuando transfieres cargas de trabajo de JAX a Pathways o cuando escalas una carga de trabajo de JAX con Pathways a una gran cantidad de aceleradores.

Antes de comenzar

Asegúrate de tener lo siguiente:

Índice de proceso

JAX con Pathways trata todos los dispositivos de tu clúster de Pathways como locales. Esto simplifica la administración de dispositivos y permite que JAX utilice todos los recursos disponibles. En la práctica, esto significa lo siguiente:

  • jax.process_index() siempre es 0 para todos los dispositivos.
  • jax.devices() y jax.local_devices() devuelven todos los dispositivos de TPU en todo el trabajo.

Tipo de hardware y ubicación

Para obtener el mejor rendimiento, coloca todos los componentes de Pathways y el trabajo del usuario en la misma Google Cloud zona de nube. Usa una CPU grande, como el proxy y el administrador de recursos de IFRT. Recomendamos al menos un n2-standard-64 dedicado con 64 CPU virtuales y 256 GB de memoria.

PathwaysUtils

Pathways-utils es un repositorio de GitHub basado en Python que proporciona utilidades y herramientas esenciales que te permiten optimizar la implementación y la ejecución de cargas de trabajo de JAX en la arquitectura de Pathways on Cloud. Este paquete controla las adaptaciones necesarias para el entorno de la nube, lo que permite que los desarrolladores de JAX se enfoquen en sus flujos de trabajo principales de aprendizaje automático con una configuración mínima específica de la plataforma. Específicamente, ofrece lo siguiente:

  • Un backend de JAX "proxy": Este backend personalizado permite que tu aplicación de JAX use la infraestructura de Pathways configurando la variable de entorno JAX_PLATFORMS=proxy.
  • Utilidades de generación de perfiles integradas: Capacidades de generación de perfiles que te permiten comprender el rendimiento de tu aplicación. Si usas las APIs de generación de perfiles de JAX estándar, como jax.profiler.start_trace y jax.profiler.start_server, puedes generar perfiles no solo de tu código de JAX, sino también de los componentes subyacentes de Pathways, lo que proporciona una vista integral de la ejecución dentro del entorno de la nube.
  • Creación de puntos de control distribuidos con Orbax: Un controlador de puntos de control de Orbax personalizado que te permite usar puntos de control distribuidos y restablecer tus puntos de control cuando usas la biblioteca de Orbax en el entorno de Pathways. El objetivo de esta integración es funcionar sin requerir cambios en tu código de guardado de puntos de control de Orbax existente, siempre y cuando importe pathwaysutils.
  • Primitivas de entrenamiento elástico: Proporciona primitivas de entrenamiento elástico fundamentales que puedes usar para compilar flujos de trabajo de entrenamiento sólidos y escalables con Pathways. Estas primitivas permiten que tus trabajos de entrenamiento se adapten de forma dinámica a los cambios en los recursos disponibles, lo que mejora la eficiencia y la resiliencia en los entornos de nube.

Creación de puntos de control

Orbax se probó exhaustivamente con Pathways para la creación y el restablecimiento de puntos de control distribuidos con Cloud Storage. Cuando llamas a import pathwaysutils; pathwaysutils.initialize() en tu train.py, se registra un ArrayHandler personalizado que controla de manera eficiente las operaciones de puntos de control a través del proxy de IFRT, lo que permite que los trabajadores de Pathways en aceleradores guarden y restablezcan datos directamente.

Python ubicado en el mismo lugar

Python colocalizado es una API de JAX de código abierto que te permite ejecutar código de Python especificado por el usuario directamente en los hosts de TPU o GPU, lo que es más sencillo en JAX con varios controladores. Esto permite que las tareas que requieren más procesamiento, como la carga de datos y la creación de puntos de control, eviten la transferencia de datos entre las máquinas cliente y las TPU. Para configurar tu clúster de Pathways para que ejecute la API de JAX de Python colocada, sigue las instrucciones que se indican en el README de Python colocado. Estas instrucciones explican cómo iniciar un proceso secundario de Python colocado junto con tus trabajadores de Pathways.

Carga de datos

Durante el entrenamiento, cargamos lotes de un conjunto de datos de forma repetida para alimentar el modelo. Es importante tener un cargador de datos asíncrono y eficiente que divida el lote en fragmentos entre los hosts para evitar que los aceleradores se queden sin trabajo. Cuando se ejecuta el entrenamiento con Pathways, el cargador de datos se ejecuta en una VM de CPU (a diferencia de una VM de TPU que se usa en configuraciones de varios controladores) y envía datos a las VMs de TPU. Esto genera una latencia más alta en la lectura de datos, pero se mitiga parcialmente leyendo por adelantado una cantidad X de lotes en el host de la CPU y enviando los datos leídos de forma asíncrona a las TPU. Esta solución es suficiente cuando se ejecuta a una escala pequeña o mediana.

Para obtener un rendimiento óptimo a gran escala, te recomendamos que ubiques tu canalización de datos de entrada en el mismo lugar usando Python ubicado en el mismo lugar para ejecutar tu canalización de datos directamente en los aceleradores. Esto elimina el cuello de botella de la CPU y aprovecha las interconexiones rápidas de la TPU para la transferencia de datos.

Puedes encontrar una implementación de referencia de la migración de una canalización de entrada basada en TFDS en la implementación de RemoteIterator en multihost_dataloading.py. Esta implementación funciona tanto en JAX con varios controladores como en Pathways de forma distribuida con la API de Python JAX colocada.

Control de versiones de Jax

Las versiones de Pathways están estrechamente vinculadas con las versiones de JAX para garantizar la compatibilidad y la estabilidad. Para evitar posibles problemas, verifica que tus artefactos de Pathways y tu versión de JAX estén alineados. Cada versión de Pathways especifica claramente las versiones de JAX compatibles a través de una etiqueta con el formato jax-<version>.

Caché de compilación

La caché de compilación persistente de Pathways es una función que permite que los servidores de Pathways almacenen ejecutables de XLA compilados en una ubicación persistente, como Cloud Storage, para evitar la compilación redundante. Esta función está habilitada de forma predeterminada. La ubicación de la caché se pasa como una marca --gcs_scratch_location a los contenedores del administrador de recursos y del trabajador de Pathways. Para mantener al mínimo los costos de almacenamiento asociados, la caché adjunta una política de ciclo de vida a la ubicación de Cloud Storage. Hay un límite de 50 políticas por bucket de Cloud Storage. Por lo tanto, te recomendamos que uses una ubicación común de Cloud Storage en todas las cargas de trabajo.

Esta caché es similar a la caché de compilación de JAX, que pathwaysutils.initialize() inhabilita para las cargas de trabajo de Pathways.

Genera perfiles

Puedes usar el generador de perfiles de JAX para generar registros de un programa de JAX. Existen dos formas comunes de obtener asistencia con Rutas de aprendizaje:

  • Programática
    • Captura perfiles de forma programática desde tu código de JAX
  • Manual
    • Cómo capturar perfiles bajo demanda después de iniciar el servidor del generador de perfiles desde tu código de JAX

En ambos casos, los perfiles se escriben en un bucket de Cloud Storage. Se crearán varios archivos de registro en el bucket de Cloud Storage, posiblemente en diferentes carpetas de marcas de tiempo, por ejemplo:

  • Proceso principal de Python que invocó el registro (por lo general, la VM de tu notebook): <jax-client-vm-name>.xplane.pb
  • Proxy de IFRT de Pathways: client.<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb
  • Administrador de recursos de Rutas de aprendizaje: server.*<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb
  • Trabajadores de rutas: server.*<tpu-node-name>.xplane.pb

Estos archivos de registro se pueden analizar con TensorBoard ejecutando el siguiente comando. Para obtener más información sobre TensorBoard y todas sus herramientas de generación de perfiles, consulta Optimiza el rendimiento de TensorFlow con el generador de perfiles.

# verify trace files are present
gsutil ls -l -r gs://BUCKET/PREFIX

# View on tensorboard
tensorboard --logdir=gs://BUCKET/PREFIX

Reemplaza lo siguiente:

  • BUCKET : Es un bucket de Cloud Storage para almacenar los archivos de registro.
  • PREFIX: Es una ruta de acceso dentro de tu bucket de Cloud Storage para almacenar los archivos de registro.

Captura de perfil programática

Captura un perfil desde tu código. Los perfiles se guardan en gs://<bucket>/<prefix> en un directorio de marca de tiempo.

import jax
import pathwaysutils

pathwaysutils.initialize()

jax.profiler.start_trace("gs://BUCKET/PREFIX")

# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()

jax.profiler.stop_trace()

Captura manual de perfiles

Para capturar manualmente la información del perfil, debes iniciar el servidor del generador de perfiles desde tu código de Python:

import jax
import pathwaysutils

pathwaysutils.initialize()
jax.profiler.start_server(jax_profiler_port)

# Your JAX code
jax.profiler.stop_server() # this is functinoally a no-op

Mientras se ejecuta el servidor de Profiler, puedes capturar un perfil y exportar los datos a la ubicación de destino de Cloud Storage:

export DURATION_IN_SECS=6
export OUTPUT_DIR=gs://<var>BUCKET</var>/<var>PREFIX</var>

curl -d "{\"duration_ms\":\"${DURATION_IN_SECS} * 1000 }}\", \"repository_path\":\"${OUTPUT_DIR}\"}" -H "Content-Type: application/json" -X POST http://localhost:<jax_profiler_port

Puedes encontrar información de sincronización para los métodos del cliente proxy de IFRT, como Compile y Execute, en el registro de tu programa. Estos eventos, que detallan las interacciones con el servidor de gRPC del proxy de IFRT durante la compilación y la ejecución, aparecen en el subproceso llamado GrpcClientSessionUserFuturesWorkQueue. Si examinas este subproceso en tu registro, puedes obtener estadísticas sobre el rendimiento de estas operaciones.

Marcas de XLA

Cuando usas Pathways, debes establecer las marcas de XLA en el contenedor de pathways-proxy. Puedes hacerlo con XPK o la API de PathwaysJob.

Cuando uses XPK, establece marcas de XLA como las siguientes:

--custom-pathways-proxy-server-args="--xla_flag_1=value1 --xla_flag_2=value2"

Cuando uses la API de PathwaysJob, configura las marcas de XLA de la siguiente manera:

apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USER
spec:
  customComponents:
  - componentType: proxy_server
    customFlags:
    - --xla_flag_1=value1
    - --xla_flag_2=value2

Reemplaza lo siguiente:

  • USER : Tu Google Cloud nombre de usuario
  • value[n]: Son las marcas de XLA que deseas establecer.

Volcado del HLO

Para analizar en detalle las entradas del optimizador de alto nivel (HLO) que se proporcionan al compilador de XLA, puedes configurar Pathways para que vuelque el HLO en una ubicación especificada de Cloud Storage de la siguiente manera:

apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USER
spec:
  customComponents:
  - componentType: proxy_server
    customEnv:
    - name: XLA_FLAGS
      value: "--xla_dump_to=gs://your-gcs-bucket/your-desired-prefix/"

¿Qué sigue?