Debido a la naturaleza distribuida de JAX con Pathways, es posible que algunas operaciones no se escalen bien debido a la sobrecarga de comunicación. Si bien Pathways minimiza estos gastos generales con funciones como el envío asíncrono, hay algunas cosas que debes tener en cuenta cuando transfieres cargas de trabajo de JAX a Pathways o cuando escalas una carga de trabajo de JAX con Pathways a una gran cantidad de aceleradores.
Antes de comenzar
Asegúrate de tener lo siguiente:
- Herramientas de Kubernetes instaladas
- Instalaste la CLI de gcloud
- Habilitaste la API de TPU
- Habilitaste la API de Google Kubernetes Engine
Índice de proceso
JAX con Pathways trata todos los dispositivos de tu clúster de Pathways como locales. Esto simplifica la administración de dispositivos y permite que JAX utilice todos los recursos disponibles. En la práctica, esto significa lo siguiente:
jax.process_index()siempre es 0 para todos los dispositivos.jax.devices()yjax.local_devices()devuelven todos los dispositivos de TPU en todo el trabajo.
Tipo de hardware y ubicación
Para obtener el mejor rendimiento, coloca todos los componentes de Pathways y el trabajo del usuario en la misma Google Cloud zona de nube. Usa una CPU grande, como el proxy y el administrador de recursos de IFRT. Recomendamos al menos un n2-standard-64 dedicado con 64 CPU virtuales y 256 GB de memoria.
PathwaysUtils
Pathways-utils es un repositorio de GitHub basado en Python que proporciona utilidades y herramientas esenciales que te permiten optimizar la implementación y la ejecución de cargas de trabajo de JAX en la arquitectura de Pathways on Cloud. Este paquete controla las adaptaciones necesarias para el entorno de la nube, lo que permite que los desarrolladores de JAX se enfoquen en sus flujos de trabajo principales de aprendizaje automático con una configuración mínima específica de la plataforma. Específicamente, ofrece lo siguiente:
- Un backend de JAX "proxy": Este backend personalizado permite que tu aplicación de JAX use la infraestructura de Pathways configurando la variable de entorno
JAX_PLATFORMS=proxy. - Utilidades de generación de perfiles integradas: Capacidades de generación de perfiles que te permiten comprender el rendimiento de tu aplicación. Si usas las APIs de generación de perfiles de JAX estándar, como
jax.profiler.start_traceyjax.profiler.start_server, puedes generar perfiles no solo de tu código de JAX, sino también de los componentes subyacentes de Pathways, lo que proporciona una vista integral de la ejecución dentro del entorno de la nube. - Creación de puntos de control distribuidos con Orbax: Un controlador de puntos de control de Orbax personalizado que te permite usar puntos de control distribuidos y restablecer tus puntos de control cuando usas la biblioteca de Orbax en el entorno de Pathways. El objetivo de esta integración es funcionar sin requerir cambios en tu código de guardado de puntos de control de Orbax existente, siempre y cuando importe
pathwaysutils. - Primitivas de entrenamiento elástico: Proporciona primitivas de entrenamiento elástico fundamentales que puedes usar para compilar flujos de trabajo de entrenamiento sólidos y escalables con Pathways. Estas primitivas permiten que tus trabajos de entrenamiento se adapten de forma dinámica a los cambios en los recursos disponibles, lo que mejora la eficiencia y la resiliencia en los entornos de nube.
Creación de puntos de control
Orbax se probó exhaustivamente con Pathways para la creación y el restablecimiento de puntos de control distribuidos con Cloud Storage. Cuando llamas a import pathwaysutils; pathwaysutils.initialize() en tu train.py, se registra un ArrayHandler personalizado que controla de manera eficiente las operaciones de puntos de control a través del proxy de IFRT, lo que permite que los trabajadores de Pathways en aceleradores guarden y restablezcan datos directamente.
Python ubicado en el mismo lugar
Python colocalizado es una API de JAX de código abierto que te permite ejecutar código de Python especificado por el usuario directamente en los hosts de TPU o GPU, lo que es más sencillo en JAX con varios controladores. Esto permite que las tareas que requieren más procesamiento, como la carga de datos y la creación de puntos de control, eviten la transferencia de datos entre las máquinas cliente y las TPU. Para configurar tu clúster de Pathways para que ejecute la API de JAX de Python colocada, sigue las instrucciones que se indican en el README de Python colocado. Estas instrucciones explican cómo iniciar un proceso secundario de Python colocado junto con tus trabajadores de Pathways.
Carga de datos
Durante el entrenamiento, cargamos lotes de un conjunto de datos de forma repetida para alimentar el modelo. Es importante tener un cargador de datos asíncrono y eficiente que divida el lote en fragmentos entre los hosts para evitar que los aceleradores se queden sin trabajo. Cuando se ejecuta el entrenamiento con Pathways, el cargador de datos se ejecuta en una VM de CPU (a diferencia de una VM de TPU que se usa en configuraciones de varios controladores) y envía datos a las VMs de TPU. Esto genera una latencia más alta en la lectura de datos, pero se mitiga parcialmente leyendo por adelantado una cantidad X de lotes en el host de la CPU y enviando los datos leídos de forma asíncrona a las TPU. Esta solución es suficiente cuando se ejecuta a una escala pequeña o mediana.
Para obtener un rendimiento óptimo a gran escala, te recomendamos que ubiques tu canalización de datos de entrada en el mismo lugar usando Python ubicado en el mismo lugar para ejecutar tu canalización de datos directamente en los aceleradores. Esto elimina el cuello de botella de la CPU y aprovecha las interconexiones rápidas de la TPU para la transferencia de datos.
Puedes encontrar una implementación de referencia de la migración de una canalización de entrada basada en TFDS en la implementación de RemoteIterator en multihost_dataloading.py.
Esta implementación funciona tanto en JAX con varios controladores como en Pathways de forma distribuida con la API de Python JAX colocada.
Control de versiones de Jax
Las versiones de Pathways están estrechamente vinculadas con las versiones de JAX para garantizar la compatibilidad y la estabilidad. Para evitar posibles problemas, verifica que tus artefactos de Pathways y tu versión de JAX estén alineados. Cada versión de Pathways especifica claramente las versiones de JAX compatibles a través de una etiqueta con el formato jax-<version>.
Caché de compilación
La caché de compilación persistente de Pathways es una función que permite que los servidores de Pathways almacenen ejecutables de XLA compilados en una ubicación persistente, como Cloud Storage, para evitar la compilación redundante. Esta función está habilitada de forma predeterminada. La ubicación de la caché se pasa como una marca --gcs_scratch_location a los contenedores del administrador de recursos y del trabajador de Pathways. Para mantener al mínimo los costos de almacenamiento asociados, la caché adjunta una política de ciclo de vida a la ubicación de Cloud Storage. Hay un límite de 50 políticas por bucket de Cloud Storage. Por lo tanto, te recomendamos que uses una ubicación común de Cloud Storage en todas las cargas de trabajo.
Esta caché es similar a la caché de compilación de JAX, que pathwaysutils.initialize() inhabilita para las cargas de trabajo de Pathways.
Genera perfiles
Puedes usar el generador de perfiles de JAX para generar registros de un programa de JAX. Existen dos formas comunes de obtener asistencia con Rutas de aprendizaje:
- Programática
- Captura perfiles de forma programática desde tu código de JAX
- Manual
- Cómo capturar perfiles bajo demanda después de iniciar el servidor del generador de perfiles desde tu código de JAX
En ambos casos, los perfiles se escriben en un bucket de Cloud Storage. Se crearán varios archivos de registro en el bucket de Cloud Storage, posiblemente en diferentes carpetas de marcas de tiempo, por ejemplo:
- Proceso principal de Python que invocó el registro (por lo general, la VM de tu notebook):
<jax-client-vm-name>.xplane.pb - Proxy de IFRT de Pathways:
client.<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb - Administrador de recursos de Rutas de aprendizaje:
server.*<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb - Trabajadores de rutas:
server.*<tpu-node-name>.xplane.pb
Estos archivos de registro se pueden analizar con TensorBoard ejecutando el siguiente comando. Para obtener más información sobre TensorBoard y todas sus herramientas de generación de perfiles, consulta Optimiza el rendimiento de TensorFlow con el generador de perfiles.
# verify trace files are present gsutil ls -l -r gs://BUCKET/PREFIX # View on tensorboard tensorboard --logdir=gs://BUCKET/PREFIX
Reemplaza lo siguiente:
BUCKET: Es un bucket de Cloud Storage para almacenar los archivos de registro.PREFIX: Es una ruta de acceso dentro de tu bucket de Cloud Storage para almacenar los archivos de registro.
Captura de perfil programática
Captura un perfil desde tu código. Los perfiles se guardan en gs://<bucket>/<prefix> en un directorio de marca de tiempo.
import jax import pathwaysutils pathwaysutils.initialize() jax.profiler.start_trace("gs://BUCKET/PREFIX") # Run the operations to be profiled key = jax.random.key(0) x = jax.random.normal(key, (5000, 5000)) y = x @ x y.block_until_ready() jax.profiler.stop_trace()
Captura manual de perfiles
Para capturar manualmente la información del perfil, debes iniciar el servidor del generador de perfiles desde tu código de Python:
import jax
import pathwaysutils
pathwaysutils.initialize()
jax.profiler.start_server(jax_profiler_port)
# Your JAX code
jax.profiler.stop_server() # this is functinoally a no-op
Mientras se ejecuta el servidor de Profiler, puedes capturar un perfil y exportar los datos a la ubicación de destino de Cloud Storage:
export DURATION_IN_SECS=6
export OUTPUT_DIR=gs://<var>BUCKET</var>/<var>PREFIX</var>
curl -d "{\"duration_ms\":\"${DURATION_IN_SECS} * 1000 }}\", \"repository_path\":\"${OUTPUT_DIR}\"}" -H "Content-Type: application/json" -X POST http://localhost:<jax_profiler_port
Puedes encontrar información de sincronización para los métodos del cliente proxy de IFRT, como Compile y Execute, en el registro de tu programa. Estos eventos, que detallan las interacciones con el servidor de gRPC del proxy de IFRT durante la compilación y la ejecución, aparecen en el subproceso llamado GrpcClientSessionUserFuturesWorkQueue. Si examinas este subproceso en tu registro, puedes obtener estadísticas sobre el rendimiento de estas operaciones.
Marcas de XLA
Cuando usas Pathways, debes establecer las marcas de XLA en el contenedor de pathways-proxy. Puedes hacerlo con XPK o la API de PathwaysJob.
Cuando uses XPK, establece marcas de XLA como las siguientes:
--custom-pathways-proxy-server-args="--xla_flag_1=value1 --xla_flag_2=value2"
Cuando uses la API de PathwaysJob, configura las marcas de XLA de la siguiente manera:
apiVersion: pathways-job.pathways.domain/v1 kind: PathwaysJob metadata: name: pathways-USER spec: customComponents: - componentType: proxy_server customFlags: - --xla_flag_1=value1 - --xla_flag_2=value2
Reemplaza lo siguiente:
USER: Tu Google Cloud nombre de usuariovalue[n]: Son las marcas de XLA que deseas establecer.
Volcado del HLO
Para analizar en detalle las entradas del optimizador de alto nivel (HLO) que se proporcionan al compilador de XLA, puedes configurar Pathways para que vuelque el HLO en una ubicación especificada de Cloud Storage de la siguiente manera:
apiVersion: pathways-job.pathways.domain/v1 kind: PathwaysJob metadata: name: pathways-USER spec: customComponents: - componentType: proxy_server customEnv: - name: XLA_FLAGS value: "--xla_dump_to=gs://your-gcs-bucket/your-desired-prefix/"
¿Qué sigue?
- Crea un clúster de GKE con Pathways
- Inferencia multihost con Pathways
- Cargas de trabajo por lotes con Pathways
- Modo interactivo de Rutas de aprendizaje
- Capacitación resiliente con Pathways
- Rutas de solución de problemas