Debido a la naturaleza distribuida de JAX con Pathways, es posible que algunas operaciones no se ajusten bien debido a las sobrecargas de comunicación. Si bien Pathways minimiza estas sobrecargas con funciones como el envío asíncrono, debes tener en cuenta algunas cosas cuando transfieres cargas de trabajo de JAX a Pathways o cuando ajustas una carga de trabajo de JAX con Pathways a una gran cantidad de aceleradores.
Antes de comenzar
Asegúrate de tener lo siguiente:
- Herramientas de Kubernetes instaladas
- gcloud CLI instalada
- API de TPU habilitada
- API de Google Kubernetes Engine habilitada
Índice de procesos
JAX con Pathways trata todos los dispositivos de tu clúster de Pathways como locales. Esto simplifica la administración de dispositivos y permite que JAX utilice todos los recursos disponibles. En la práctica, esto significa lo siguiente:
jax.process_index()siempre es 0 para todos los dispositivos.jax.devices()yjax.local_devices()muestran todos los dispositivos TPU en todo el trabajo.
Tipo de hardware y ubicación compartida
Para obtener un mejor rendimiento, coloca todos los componentes de Pathways y el trabajo del usuario en la
misma Google Cloud zona de la nube. Usa una CPU grande, como el proxy IFRT y el administrador de recursos. Recomendamos al menos una n2-standard-64 dedicada que viene con 64 vCPU y 256 GB de memoria.
PathwaysUtils
Pathways-utils es un repositorio de GitHub basado en Python que proporciona herramientas y utilidades esenciales que te permiten optimizar la implementación y la ejecución de cargas de trabajo de JAX en la arquitectura de Pathways en Cloud. Este paquete controla las adaptaciones necesarias para el entorno de la nube, lo que permite que los desarrolladores de JAX se enfoquen en sus flujos de trabajo principales de aprendizaje automático con una configuración mínima específica de la plataforma. En particular, ofrece lo siguiente:
- Un backend de JAX "proxy": Este backend personalizado permite que tu aplicación de JAX use la infraestructura de Pathways configurando la variable de entorno
JAX_PLATFORMS=proxy. - Utilidades de generación de perfiles integradas: Capacidades de generación de perfiles que te permiten comprender el rendimiento de tu aplicación. Si usas las APIs de generación de perfiles de JAX estándar, como
jax.profiler.start_traceyjax.profiler.start_server, puedes generar perfiles no solo de tu código de JAX, sino también de los componentes subyacentes de Pathways, lo que proporciona una vista integral de la ejecución en el entorno de la nube. - Creación de puntos de control distribuidos con Orbax: Un controlador de puntos de control de Orbax personalizado que te permite usar puntos de control distribuidos y restablecer tus puntos de control cuando usas la biblioteca de Orbax en el entorno de Pathways. El objetivo de esta integración es funcionar sin requerir cambios en tu código de creación de puntos de control de Orbax existente, siempre que importe
pathwaysutils. - Primitivas de entrenamiento elástico: Proporciona primitivas de entrenamiento elástico fundamentales que puedes usar para compilar flujos de trabajo de entrenamiento sólidos y escalables con Pathways. Estas primitivas permiten que tus trabajos de entrenamiento se adapten de forma dinámica a los cambios en los recursos disponibles, lo que mejora la eficiencia y la resiliencia en los entornos de la nube.
Creación de puntos de control
Orbax se prueba minuciosamente con Pathways para
la creación de puntos de control distribuidos y el restablecimiento con Cloud Storage. Cuando llamas a import pathwaysutils; pathwaysutils.initialize() en tu train.py, se registra un ArrayHandler personalizado que controla de manera eficiente las operaciones de puntos de control a través del IFRT proxy, lo que permite que los trabajadores de Pathways en los aceleradores guarden y restablezcan datos directamente.
Python ubicado
Python ubicado es una API de JAX de código abierto que te permite ejecutar código de Python especificado por el usuario directamente en los hosts de TPU o GPU, lo que es más sencillo en JAX de varios controladores. Esto permite que las tareas que requieren más procesamiento, como la carga de datos y la creación de puntos de control, eviten la transferencia de datos entre el cliente y las máquinas TPU. Para configurar tu clúster de Pathways para ejecutar la API de JAX de Python ubicado, sigue las instrucciones que se indican en el archivo README de Python ubicado. En estas instrucciones, se explica cómo iniciar un sidecar de Python ubicado junto con tus trabajadores de Pathways.
Carga de datos
Durante el entrenamiento, cargamos lotes de un conjunto de datos de forma repetida para alimentar el modelo. Es importante tener un cargador de datos asíncrono y eficiente que divida el lote en fragmentos entre los hosts para evitar que los aceleradores se queden sin trabajo. Cuando se ejecuta el entrenamiento con Pathways, el cargador de datos se ejecuta en una VM de CPU (a diferencia de una VM de TPU que se usa en configuraciones de varios controladores) y envía datos a las VMs de TPU. Esto genera una latencia más alta en la lectura de datos, pero se mitiga parcialmente leyendo por adelantado X cantidad de lotes en el host de la CPU y enviando los datos leídos de forma asíncrona a las TPU. Esta solución es suficiente cuando se ejecuta a una escala pequeña o mediana.
Para obtener un rendimiento óptimo a gran escala, te recomendamos que ubiques en el mismo lugar tu canalización de datos de entrada con Python ubicado para ejecutar tu canalización de datos directamente en los aceleradores. Esto elimina el cuello de botella de la CPU y aprovecha las interconexiones rápidas de la TPU para la transferencia de datos.
Puedes encontrar una implementación de referencia de la migración de una canalización de entrada basada en TFDS
en la implementación de RemoteIterator en
multihost_dataloading.py.
Esta implementación funciona tanto en JAX de varios controladores como en Pathways de forma distribuida con la API de JAX de Python ubicado.
Control de versiones de Jax
Las versiones de Pathways están estrechamente vinculadas con las versiones de JAX para garantizar la compatibilidad y la estabilidad. Para evitar posibles problemas, verifica que tus artefactos de Pathways y tu versión de JAX estén alineados. Cada versión de Pathways especifica claramente las
versiones compatibles de JAX a través de una etiqueta con el formato jax-<version>.
Caché de compilación
La caché de compilación persistente de Pathways es una función que permite que los servidores de Pathways almacenen ejecutables de XLA compilados en una ubicación persistente, como Cloud Storage, para evitar la compilación redundante. Esta función está habilitada de forma predeterminada. La ubicación de la caché se pasa como marca --gcs_scratch_location a los contenedores de trabajador y administrador de recursos de Pathways. Para mantener los costos de almacenamiento asociados al mínimo, la caché adjunta una política de ciclo de vida a la ubicación de Cloud Storage. Existe un límite de 50 políticas por bucket de Cloud Storage. Por lo tanto, recomendamos usar una ubicación común de Cloud Storage en todas las cargas de trabajo.
Esta caché es similar a la caché de compilación de JAX
que inhabilita pathwaysutils.initialize() para las cargas de trabajo de Pathways.
Se requieren los siguientes permisos de Cloud Storage para la caché de compilación:
storage.buckets.get: Para recuperar metadatos de bucket.storage.buckets.update: Es esencial para que Pathways configure políticas de ciclo de vida de objetos para aplicar el TTL para la expulsión de la caché.storage.objects.list: Para enumerar los objetos de caché existentes dentro del bucket.storage.objects.create: Para escribir nuevos ejecutables compilados en la caché.storage.objects.get: Para leer ejecutables almacenados en caché desde el bucket.
Generación de perfiles
Puedes usar el generador de perfiles de JAX para generar seguimientos de un programa de JAX. Existen dos formas comunes compatibles con Pathways:
- Programática
- Captura perfiles de forma programática desde tu código de JAX
- Manual
- Captura perfiles a pedido después de iniciar el servidor del generador de perfiles desde tu código de JAX
En ambos casos, los perfiles se escriben en un bucket de Cloud Storage. Se crearán varios archivos de seguimiento en el bucket de Cloud Storage, posiblemente en diferentes carpetas de marcas de tiempo, por ejemplo:
- Proceso principal de Python que invocó el seguimiento (por lo general, tu VM de notebook):
<jax-client-vm-name>.xplane.pb - Proxy IFRT de Pathways:
client.<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb - Administrador de recursos de Pathways:
server.*<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb - Trabajadores de Pathways:
server.*<tpu-node-name>.xplane.pb
Estos archivos de seguimiento se pueden analizar con TensorBoard ejecutando el siguiente comando. Para obtener más información sobre TensorBoard y todas sus herramientas de generación de perfiles, consulta Cómo optimizar el rendimiento de TensorFlow con el generador de perfiles.
# verify trace files are present gsutil ls -l -r gs://BUCKET/PREFIX # View on tensorboard tensorboard --logdir=gs://BUCKET/PREFIX
Reemplaza lo siguiente:
BUCKET: Un bucket de Cloud Storage para almacenar los archivos de seguimientoPREFIX: Una ruta de acceso dentro de tu bucket de Cloud Storage para almacenar los archivos de seguimiento
Captura de perfiles programática
Captura un perfil desde tu código. Los perfiles se guardan dentro de
gs://<bucket>/<prefix> en un directorio de marcas de tiempo.
import jax import pathwaysutils pathwaysutils.initialize() jax.profiler.start_trace("gs://BUCKET/PREFIX") # Run the operations to be profiled key = jax.random.key(0) x = jax.random.normal(key, (5000, 5000)) y = x @ x y.block_until_ready() jax.profiler.stop_trace()
Captura de perfiles manual
Para capturar información de perfiles de forma manual, debes iniciar el servidor del generador de perfiles desde tu código de Python:
import jax
import pathwaysutils
pathwaysutils.initialize()
jax.profiler.start_server(jax_profiler_port)
# Your JAX code
jax.profiler.stop_server() # this is functinoally a no-op
Mientras se ejecuta el servidor del generador de perfiles, puedes capturar un perfil y exportar los datos a la ubicación de destino de Cloud Storage:
export DURATION_IN_SECS=6
export OUTPUT_DIR=gs://<var>BUCKET</var>/<var>PREFIX</var>
curl -d "{\"duration_ms\":\"${DURATION_IN_SECS} * 1000 }}\", \"repository_path\":\"${OUTPUT_DIR}\"}" -H "Content-Type: application/json" -X POST http://localhost:<jax_profiler_port
Puedes encontrar información de sincronización para los métodos de cliente del proxy IFRT, como Compile y Execute, en el seguimiento de tu programa. Estos eventos, que detallan las interacciones con el servidor gRPC del proxy IFRT durante la compilación y la ejecución, aparecen en el subproceso llamado GrpcClientSessionUserFuturesWorkQueue. Si examinas este subproceso en tu seguimiento, puedes obtener estadísticas sobre el rendimiento de estas operaciones.
Marcas de XLA
Cuando usas Pathways, debes configurar las marcas de XLA en el contenedor de pathways-proxy. Puedes hacerlo con XPK o la API de PathwaysJob.
Cuando uses XPK, configura las marcas de XLA de la siguiente manera:
--custom-pathways-proxy-server-args="--xla_flag_1=value1 --xla_flag_2=value2"
Cuando uses la API de PathwaysJob, configura las marcas de XLA de la siguiente manera:
apiVersion: pathways-job.pathways.domain/v1 kind: PathwaysJob metadata: name: pathways-USER spec: customComponents: - componentType: proxy_server customFlags: - --xla_flag_1=value1 - --xla_flag_2=value2
Reemplaza lo siguiente:
USER: tu Google Cloud nombre de usuariovalue[n]: Las marcas de XLA que deseas configurar
Volcado de HLO
Para profundizar en las entradas del optimizador de alto nivel (HLO) que se proporcionan al compilador de XLA, puedes configurar Pathways para que vuelque el HLO a una ubicación especificada de Cloud Storage de la siguiente manera:
apiVersion: pathways-job.pathways.domain/v1 kind: PathwaysJob metadata: name: pathways-USER spec: customComponents: - componentType: proxy_server customEnv: - name: XLA_FLAGS value: "--xla_dump_to=gs://your-bucket/your-desired-prefix/"
¿Qué sigue?
- Crea un clúster de GKE con Pathways
- Inferencia de varios hosts con Pathways
- Cargas de trabajo por lotes con Pathways
- Modo interactivo de Pathways
- Entrenamiento resistente con Pathways
- Solución de problemas de Pathways