Compila IA de producción en Cloud TPU con JAX

La pila de IA de JAX extiende el núcleo numérico de JAX con una colección de bibliotecas componibles respaldadas por Google, lo que la convierte en una plataforma de origen sólida y de extremo a extremo con código abierto para el aprendizaje automático a gran escala. Por lo tanto, la pila de IA de JAX consta de un ecosistema integral y sólido que aborda todo el ciclo de vida del AA:

  • Base a escala industrial: La pila de IA de JAX está diseñada para una escala masiva, ya que aprovecha rutas de aprendizaje de AA para organizar el entrenamiento en decenas de miles de chips y Orbax para la creación de puntos de control asíncronos resilientes con alta capacidad de procesamiento, lo que permite el entrenamiento de modelos de vanguardia a nivel de producción.

  • Kit de herramientas completo y listo para producción: La pila de IA de JAX proporciona un conjunto integral de bibliotecas para todo el proceso de desarrollo: Flax para la formación flexible de modelos, Optax para estrategias de optimización componibles y Grain para las canalizaciones de datos determinísticas esenciales para las ejecuciones reproducibles a gran escala.

  • Rendimiento máximo y especializado: Para lograr el máximo uso del hardware, la pila de IA de JAX ofrece bibliotecas especializadas, como Tokamax para kernels personalizados de vanguardia, Qwix para la cuantización no intrusiva que aumenta la velocidad de entrenamiento y de inferencia, y XProf para la creación de perfiles de rendimiento profundos y con integración de hardware.

  • Ruta completa a la producción: La pila de IA de JAX proporciona una transición fluida de la investigación a la implementación. Esto incluye MaxText como referencia escalable para el entrenamiento de modelos de base, Tunix para el aprendizaje por refuerzo (RL) y la alineación de vanguardia, y una solución de inferencia unificada con la integración de vLLM de TPU y el entorno de ejecución de la entrega de JAX.

La filosofía de la pila de IA de JAX se basa en componentes con acoplamiento bajo, cada uno de los cuales hace algo bien. En lugar de ser un framework de AA monolítico, JAX en sí tiene un alcance limitado y se enfoca en operaciones de array y transformaciones de programas eficientes. El ecosistema se basa en este framework principal para proporcionar una amplia variedad de funcionalidades relacionadas con el entrenamiento de modelos de AA y otros tipos de cargas de trabajo, como la computación científica.

Este sistema de componentes con acoplamiento bajo te permite seleccionar y combinar bibliotecas de la mejor manera para satisfacer tus requisitos. Desde la perspectiva de la ingeniería de software, esta arquitectura también te permite actualizar la funcionalidad que, tradicionalmente, se consideraría como componentes principales del framework (por ejemplo, canalizaciones de datos y creación de puntos de control) sin el riesgo de desestabilizar el framework principal ni de quedar atascarse en los ciclos de lanzamiento. Dado que la mayor parte de la funcionalidad se implementa en bibliotecas en lugar de cambios en un framework monolítico, esto hace que la biblioteca numérica principal sea más duradera y adaptable a los cambios futuros en el panorama tecnológico.

En las siguientes secciones, se proporciona una descripción general técnica de la pila de IA de JAX, sus funciones clave, las decisiones de diseño que las respaldasn y cómo se combinan para crear una plataforma duradera para las cargas de trabajo de AA modernas.

La pila de IA de JAX y otros componentes del ecosistema

Componente Función o descripción
Componentes y núcleo de la pila de IA de JAX1
JAX Cálculo de arrays y transformación de programas orientados a aceleradores (JIT, grad, vmap, pmap).
Flax Biblioteca de formación flexible de redes neuronales para la creación y modificación intuitivas de modelos.
Optax Biblioteca de transformaciones de optimización y procesamiento de gradientes componibles.
Orbax Biblioteca de puntos de control distribuidos "a cualquier escala" para la resiliencia del entrenamiento a gran escala.
Grain Biblioteca de canalizaciones de datos de entrada escalable, determinística y con capacidad de crear puntos de control.
Pila de IA de JAX: infraestructura
XLA Compilador de aprendizaje automático de código abierto para TPU, CPU y GPU.
Rutas de aprendizaje Entorno de ejecución distribuido para organizar la computación en decenas de miles de chips.
Pila de IA de JAX: des. avanzado
Pallas Una extensión de JAX para escribir kernels personalizados de bajo nivel y alto rendimiento implementados en Python.
Tokamax Una biblioteca seleccionada de kernels personalizados de alto rendimiento y vanguardia (por ejemplo, Attention).
Qwix Una biblioteca integral y no intrusiva para la cuantización (PTQ, QAT, QLoRA).
Pila de IA de JAX: aplicación
MaxText/MaxDiffusion Frameworks de referencia emblemáticos y escalables para entrenar modelos de base (por ejemplo, LLM y Diffusion).
Tunix Es un framework para la alineación y el posentrenamiento de vanguardia (RLHF, DPO).
vLLM Una solución de inferencia de LLM de alto rendimiento que usa la integración incorporada del framework de vLLM.
XProf Un generador profundo de perfiles integrado en el hardware para el análisis del rendimiento de todo el sistema.

1 Se incluye en el paquete de Python jax-ai-stack.

Figura 1: Componentes de la pila y el ecosistema de IA de JAX

Pila de IA de JAX

El imperativo arquitectónico: rendimiento más allá de los frameworks

A medida que las arquitecturas de modelos convergen (por ejemplo, en los Transformers de mezcla de expertos [MoE] multimodales), la búsqueda del máximo rendimiento está dando lugar a la aparición de los megakernels. Un megakernel es, de hecho, todo el pase hacia delante (o una gran parte) de un modelo específico, codificado de forma manualmente con una API de nivel inferior, como el SDK de CUDA en las GPU de NVIDIA. Este enfoque logra la máxima utilización del hardware superponiendo de forma agresiva el procesamiento, la memoria y la comunicación. El trabajo reciente de la comunidad de investigación demostró que este enfoque puede generar ganancias significativas en la capacidad de procesamiento, más del 22% en algunos casos, para la inferencia en GPU. Esta tendencia no se limita a la inferencia. La evidencia indica que algunos esfuerzos de entrenamiento a gran escala implicaron el control de hardware de bajo nivel para lograr ganancias sustanciales en la eficiencia.

Si esta tendencia se acelera, todos los frameworks de alto nivel como existen hoy corren el riesgo de volverse menos relevantes, ya que el acceso de bajo nivel al hardware es lo que, en última instancia, importa para el rendimiento en arquitecturas estables y maduras. Esto representa un desafío para todas las pilas de AA modernas: ¿cómo proporcionar un control de hardware de nivel experto sin sacrificar la productividad y la flexibilidad de un framework de alto nivel?

Para que las TPU proporcionen una ruta clara hacia este nivel de rendimiento, el ecosistema debe exponer una capa de API más cercana al hardware, lo que permite el desarrollo de estos kernels altamente especializados. La pila de JAX está diseñada para resolver este problema ofreciendo un continuo de abstracción (consulta la figura 2), desde las optimizaciones automatizadas de alto nivel del compilador XLA hasta el control manual y detallado de la biblioteca de formación de kernels de Pallas.

Figura 2: El continuo de abstracción de JAX

Continuo de abstracción de JAX

La pila principal de IA de JAX

La pila principal de IA de JAX consta de cinco bibliotecas clave que proporcionan la base para el desarrollo de modelos:

JAX: Una base para la transformación de programas componibles y de alto rendimiento

JAX es una biblioteca de Python para la transformación de programas y el procesamiento de arrays orientados a aceleradores. Está diseñada para la computación numérica de alto rendimiento y el aprendizaje automático a gran escala. Con su modelo de programación funcional y su API similar a NumPy, JAX proporciona una base sólida para las bibliotecas de nivel superior.

Con su diseño basado en el compilador, JAX promueve la escalabilidad de forma inherente aprovechando XLA (consulta la sección de XLA) para el análisis, la optimización y la segmentación de hardware agresivos de todo el programa. El énfasis de JAX en la programación funcional (por ejemplo, funciones puras) hace que sus transformaciones de programas centrales sean más manejables y, lo que es fundamental, componibles.

Estas transformaciones principales se pueden combinar para lograr un alto rendimiento y escalabilidad de las cargas de trabajo en función del tamaño del modelo y el clúster, y los tipos de hardware:

  • jit: compilación justo a tiempo de funciones de Python en ejecutables de XLA optimizados y fusionados
  • grad: diferenciación automática, compatible con el modo directo y el modo inverso, así como con derivadas de orden superior
  • vmap: vectorización automática que permite el procesamiento por lotes y el paralelismo de datos sin problemas, sin modificar la lógica de la función
  • pmap o shard_map: paralelización automática en varios dispositivos (por ejemplo, núcleos de TPU), que constituye la base del entrenamiento distribuido

La integración perfecta con el modelo GSPMD (SPMD de uso general) de XLA permite que JAX paralelice de forma automática los cálculos en pods de TPU grandes con cambios mínimos en el código. En la mayoría de los casos, el escalamiento solo requiere anotaciones de fragmentación de alto nivel.

Flax: Formación flexible de redes neuronales

Flax simplifica la creación, la depuración y el análisis de redes neuronales en JAX, ya que proporciona un enfoque intuitivo y orientado a objetos para la compilación de modelos. Si bien la API funcional de JAX es potente, ofrece una abstracción basada en capas más familiar para los desarrolladores acostumbrados a frameworks como PyTorch, sin ninguna penalización en el rendimiento.

Este diseño simplifica la modificación o la combinación de componentes de modelos entrenados. Las técnicas como LoRA y la cuantización requieren definiciones de modelos manipulables, que la API de NNX de Flax proporciona a través de una interfaz de Python. NNX encapsula el estado del modelo, reduce la carga cognitiva del usuario y permite el recorrido y la modificación programáticos de la jerarquía del modelo.

Principales fortalezas:

  • API intuitiva orientada a objetos: Simplifica la construcción de modelos y permite casos de uso avanzados, como el reemplazo de submódulos y la inicialización parcial.
  • Coherente con Core JAX: Flax proporciona transformaciones elevadas que son completamente compatibles con el paradigma funcional de JAX, lo que ofrece el rendimiento completo de JAX con una mayor facilidad de uso para los desarrolladores.

Optax: Estrategias de optimización y procesamiento de gradientes componibles

Optax es una biblioteca de optimización y procesamiento de gradientes para JAX. Está diseñada para proporcionar a los creadores de modelos componentes básicos que se pueden recombinar de formas personalizadas para entrenar modelos de aprendizaje profundo, entre otras aplicaciones. Se basa en las capacidades de la biblioteca principal de JAX para proporcionar una biblioteca de alto rendimiento bien probada de funciones de pérdida y optimizador, y técnicas asociadas que se pueden usar para entrenar modelos de AA.

Motivación

El cálculo y la minimización de las pérdidas son fundamentales para permitir el entrenamiento de los modelos de AA. Con su compatibilidad con la diferenciación automática, la biblioteca principal de JAX proporciona las capacidades numéricas para entrenar modelos, pero no ofrece implementaciones estándar de optimizadores populares (por ejemplo, RMSProp o Adam) ni de pérdidas (por ejemplo, CrossEntropy o MSE). Si bien podrías implementar estas funciones (y algunos desarrolladores avanzados optarán por hacerlo), un error en la implementación de un optimizador podría generar problemas de calidad del modelo difíciles de diagnosticar. En lugar de que el usuario implemente estas partes esenciales, Optax proporciona implementaciones de estos algoritmos que se prueban para garantizar su corrección y rendimiento.

El campo de la teoría de la optimización se encuentra directamente en el dominio de la investigación, pero su rol central en el entrenamiento también lo convierte en una parte indispensable del entrenamiento de modelos de AA de producción. Una biblioteca que cumpla con este rol debe ser lo suficientemente flexible para adaptarse a las iteraciones rápidas de investigación, además de lo suficientemente sólida y eficiente para ser confiable en el entrenamiento de modelos de producción. También debe proporcionar implementaciones bien probadas de algoritmos de vanguardia que coincidan con las ecuaciones estándar. La biblioteca de Optax, a través de su arquitectura modular componible y su énfasis en el código correcto y legible, está diseñada para lograr este objetivo.

Diseño

Optax está diseñada para mejorar la velocidad de la investigación y la transición de esta a la producción, ya que proporciona implementaciones legibles, bien probadas y eficientes de algoritmos centrales. Optax tiene usos más allá del contexto del aprendizaje profundo. Sin embargo, en este contexto, se puede ver como una colección de funciones de pérdida, algoritmos de optimización y transformaciones de gradientes conocidos que se implementan de forma puramente funcional en línea con la filosofía de JAX. La colección de pérdidas y optimizadores conocidos permite a los usuarios comenzar con facilidad y confianza.

El enfoque modular de Optax te permite encadenar varios optimizadores seguidos de otras transformaciones comunes (por ejemplo, el recorte de gradientes) y encapsularlos con técnicas comunes, como MultiStep o Lookahead, para lograr potentes estrategias de optimización con unas pocas líneas de código. La interfaz flexible te permite investigar nuevos algoritmos de optimización y usar técnicas de optimización de segundo orden potentes, como Shampoo o Muon.

# Optax implementation of a RMSProp optimizer with a custom learning rate
#  schedule, gradient clipping and gradient accumulation.
optimizer = optax.chain(
  optax.clip_by_global_norm(GRADIENT_CLIP_VALUE),
  optax.rmsprop(learning_rate=optax.cosine_decay_schedule(init_value=lr,decay_steps=decay)),
  optax.apply_every(k=ACCUMULATION_STEPS)
)

# The same thing, in PyTorch
optimizer = optim.RMSprop(model_params, lr=LEARNING_RATE)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=TOTAL_STEPS)
for i, (inputs, targets) in enumerate(data_loader):
    # ... Training loop body ...
    if (i + 1) % ACCUMULATION_STEPS == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIP_VALUE)
        optimizer.step()
        scheduler.step()
 optimizer.zero_grad()

En el fragmento de código anterior, se muestra cómo configurar un optimizador con una tasa de aprendizaje personalizada, un recorte de gradiente y una acumulación de gradiente.

Principales fortalezas

  • Biblioteca sólida: Proporciona una biblioteca integral de pérdidas, optimizadores y algoritmos con un enfoque en la corrección y la legibilidad.
  • Transformaciones encadenables modulares: Esta API flexible te permite crear estrategias de optimización potentes y complejas de forma declarativa, sin modificar el bucle de entrenamiento.
  • Funcional y escalable: Las implementaciones puramente funcionales se integran sin problemas con los mecanismos de paralelización de JAX (por ejemplo, pmap), lo que te permite usar el mismo código para escalar desde un solo host hasta clústeres grandes.

Orbax o TensorStore: creación de puntos de control distribuidos a gran escala

Orbax es una biblioteca de puntos de control para JAX diseñada para cualquier escala, desde el entrenamiento en un solo dispositivo hasta el distribuido a gran escala. Su objetivo es unificar las implementaciones fragmentadas de la creación de puntos de control y ofrecer funciones de rendimiento esenciales, como la creación de puntos de control asíncrona y de varios niveles, a un público más amplio. Orbax permite la resiliencia necesaria para los trabajos de entrenamiento masivos y proporciona un formato flexible para publicar puntos de control.

A diferencia de los sistemas generalizados de punto de control y restablecimiento que crean instantáneas de todo el estado del sistema, la creación de puntos de control de AA con Orbax persiste de forma selectiva solo la información esencial para reanudar el entrenamiento de los pesos del modelo, el estado del optimizador y el estado del cargador de datos. Este enfoque específico minimiza el tiempo de inactividad del acelerador. Orbax logra esto superponiendo las operaciones de E/S con el procesamiento, una función esencial para las cargas de trabajo grandes. El tiempo de inactividad de los aceleradores de tiempo se reduce a la duración de la transferencia de datos del dispositivo al host, que se puede superponer aún más con el siguiente paso de entrenamiento, lo que hace que la creación de puntos de control sea casi gratuito desde una perspectiva de rendimiento.

En esencia, Orbax usa TensorStore para la lectura y escritura eficientes y paralelas de datos de array. La API de Orbax abstrae esta complejidad y ofrece una interfaz fácil de usar para controlar los PyTrees, que son la representación estándar de los modelos en JAX.

Principales fortalezas:

  • Adopción generalizada: Con millones de descargas mensuales, Orbax sirve como medio común para compartir artefactos de AA.
  • Simplifica las complejidades: Orbax abstrae las complejidades de la creación de puntos de control distribuidos, incluido el guardado asíncrono, la atomicidad y los detalles del sistema de archivos.
  • Flexible: Si bien ofrece APIs para casos de uso comunes, Orbax te permite personalizar tu flujo de trabajo para satisfacer requisitos especializados.
  • Rendimiento y escalabilidad: Las funciones como la creación de puntos de control asíncrona, un formato de almacenamiento eficiente (OCDBT) y las estrategias inteligentes de carga de datos garantizan que Orbax escale para las rondas de entrenamiento que involucran decenas de miles de nodos.

Grain: Canalizaciones de datos de entrada determinísticas y escalables

Grain es una biblioteca de Python para leer y procesar datos para entrenar y evaluar modelos de JAX. Es flexible, rápida y determinística, y admite funciones avanzadas, como la creación de puntos de control, que son esenciales para entrenar con éxito cargas de trabajo grandes. Es compatible con formatos de datos y backends de almacenamiento populares, y también proporciona una API flexible para extender la compatibilidad a formatos y backends específicos del usuario que no son compatibles de forma nativa. Si bien Grain está diseñado principalmente para funcionar con JAX, es independiente del framework, no requiere JAX para ejecutarse y también se puede usar con otros frameworks.

Motivación

Las canalizaciones de datos son una parte fundamental de la infraestructura de entrenamiento. Deben ser flexibles para que las transformaciones comunes se puedan expresar de manera eficiente y lo suficientemente rápidas para que los aceleradores estén ocupados en todo momento. También deben poder admitir varios formatos de almacenamiento y backends. Debido a sus tiempos de pasos más altos, el entrenamiento de modelos grandes a gran escala impone requisitos adicionales en la canalización de datos más allá de los que requieren las cargas de trabajo de entrenamiento normales, principalmente centrados en el determinismo y la reproducibilidad.2 La biblioteca de Grain se diseñó con una arquitectura flexible que aborda estas necesidades.


2 En el artículo 5.1 del artículo de PaLM, los autores señalan que observaron aumentos repentinos de pérdidas muy grandes a pesar de tener habilitado el recorte de gradientes. La solución fue quitar los lotes de datos infractores y reiniciar el entrenamiento desde un punto de control anterior al aumento repentino de la pérdida. Esto solo es posible con una configuración de entrenamiento completamente determinística y reproducible.

Diseño

En el nivel más alto, hay dos formas de estructurar una canalización de entrada, como un clúster independiente de trabajadores de datos o colocando a los trabajadores de datos en los hosts que controlan los aceleradores. Grain elige la segunda opción por varios motivos.

Los aceleradores se combinan con hosts potentes que suelen estar inactivos durante los pasos de entrenamiento, lo que los convierte en una opción natural para ejecutar la canalización de datos de entrada. Esta implementación tiene otras ventajas: proporciona coherencia en tu vista de la fragmentación de datos en la entrada y el procesamiento, lo que simplifica el proceso. Se podría argumentar que colocar el trabajador de datos en el host del acelerador corre el riesgo de saturar la CPU del host. Sin embargo, esto no impide descargar las transformaciones que requieren mucha capacidad de procesamiento en otro clúster con RPC.3

En cuanto a la API, con una implementación pura en Python que admite varios procesos y una API flexible, Grain te permite implementar transformaciones de datos arbitrariamente complejas redactando etapas de canalización basadas en paradigmas de transformación bien comprendidos.

Listo para usar, Grain admite formatos de datos de acceso aleatorio eficientes, como ArrayRecord y Bagz, junto con otros formatos de datos populares, como Parquet y TFDS. Grain incluye compatibilidad para leer desde sistemas de archivos locales, así como desde Cloud Storage de forma predeterminada. Además de admitir formatos y backends de almacenamiento populares, una abstracción limpia de la capa de almacenamiento te permite agregar compatibilidad con tus fuentes de datos existentes o encapsularlas para que sean compatibles con la biblioteca de Grain.


3 Así es como deben operar las canalizaciones de datos multimodales: los tokenizadores de imágenes y audio, por ejemplo, son modelos que se ejecutan en sus propios clústeres en sus propios aceleradores, y las canalizaciones de entrada realizarían llamadas RPC para convertir los ejemplos de datos en flujos de tokens.

Principales fortalezas

  • Feed de datos determinístico: La colocación del trabajador de datos con el acelerador y su vinculación con iteradores con puntos de control y un Shuffle global estable permiten que el estado del modelo y el de la canalización de datos se creen juntos en una instantánea coherente con Orbax. Esto mejora el determinismo del proceso de entrenamiento.
  • APIs flexibles para habilitar transformaciones de datos potentes: Una API de transformaciones flexible y pura de Python te permite realizar transformaciones de datos extensas dentro de la canalización de procesamiento de entrada.
  • Compatibilidad extensible con varios formatos y backends: Una API de fuentes de datos extensible admite formatos y backends de almacenamiento populares, y te permite agregar compatibilidad con formatos y backends nuevos.
  • Interfaz de depuración potente: Las herramientas de visualización de la canalización de datos y un modo de depuración te permiten inspeccionar, depurar y optimizar el rendimiento de tus canalizaciones de datos.

La pila de IA extendida de JAX

Más allá de la pila principal, un ecosistema enriquecido de bibliotecas especializadas proporciona la infraestructura, las herramientas avanzadas y las soluciones con capa de aplicación necesarias para el desarrollo del AA de extremo a extremo.

Infraestructura fundamental: Compiladores y entornos de ejecución

XLA: El motor independiente del hardware y centrado en el compilador

Motivación

XLA o Accelerated Linear Algebra es el compilador específico del dominio de Google, que está bien integrado en JAX y admite dispositivos de hardware de CPU, GPU y TPU. XLA se diseñó para ser un generador de código independiente del hardware que se orienta a las TPU, GPU y CPU.

El diseño basado en el compilador del de XLA es una elección arquitectónica fundamental que crea una ventaja duradera en un entorno de investigación en rápida evolución. En cambio, el enfoque predominante centrado en el kernel en otros ecosistemas se basa en bibliotecas optimizadas manualmente para el rendimiento. Si bien esto es muy eficaz para las arquitecturas de modelos estables y bien establecidas, crea un cuello de botella para la innovación. Cuando una nueva investigación ingresa arquitecturas novedosas, el ecosistema debe esperar a que se escriban y optimicen nuevos kernels. Sin embargo, nuestro diseño centrado en el compilador suele poder generalizarse para abarcar nuevos patrones, lo que proporciona una ruta de alto rendimiento para la investigación de vanguardia desde el primer día.

Diseño

XLA funciona compilando justo a tiempo (JIT) los grafos de procesamiento que JAX genera durante su proceso de seguimiento (por ejemplo, cuando una función se decora con @jax.jit).

Esta compilación sigue una canalización de varias etapas:

  1. Grafo de procesamiento de JAX
  2. Optimizador de alto nivel (HLO)
  3. Optimizador de bajo nivel (LLO)
  4. Código de hardware
  • Del grafo de JAX a HLO: El grafo de procesamiento de JAX se convierte en la representación de HLO de XLA. En este nivel superior, se aplican optimizaciones potentes y no específicas del hardware, como la fusión de operadores y la administración eficiente de la memoria. El dialecto StableHLO sirve como una interfaz duradera y con versiones para esta etapa.
  • De HLO a LLO: Después de las optimizaciones de alto nivel, los backends específicos del hardware toman el control y reducen la representación de HLO a un LLO orientado a la máquina.
  • De LLO al código de hardware: Finalmente, el LLO se compila en código máquina altamente eficiente. En el caso de las TPU, este código se agrupa en paquetes de palabras de instrucción muy largas (VLIW) que se envían directamente al hardware.

Para el escalamiento, el diseño de XLA se basa en el paralelismo. Emplea algoritmos para usar al máximo las unidades de multiplicación de matrices (MXU) en un chip. Entre los chips, XLA usa SPMD (datos múltiples de un solo programa), una técnica de paralelización basada en el compilador que usa un solo programa en todos los dispositivos. Este potente modelo se expone a través de las APIs de JAX, lo que te permite administrar el paralelismo de datos, modelos o canalizaciones con anotaciones de fragmentación de alto nivel.

Para patrones de paralelismo más complejos, también son posibles los datos múltiples de programas múltiples (MPMD), y bibliotecas como PartIR:MPMD permiten que los usuarios de JAX también proporcionen anotaciones de MPMD.

Principales fortalezas
  • Compilación: La compilación justo a tiempo del grafo de procesamiento permite optimizar el diseño de la memoria, la asignación de búferes y la administración de la memoria. Las alternativas, como las metodologías basadas en el kernel, imponen esa carga al desarrollador. En la mayoría de los casos, XLA puede lograr un rendimiento excelente sin comprometer la velocidad del desarrollador.
  • Paralelismo: XLA implementa varias formas de paralelismo con SPMD, y esto se expone a nivel de JAX. Esto te permite expresar estrategias de fragmentación, lo que permite la experimentación y la escalabilidad de los modelos en miles de chips.

Ruta de aprendizaje: Un entorno de ejecución unificado para el procesamiento distribuido a gran escala

La ruta de aprendizaje ofrece abstracciones para el entrenamiento y la inferencia distribuidos con tolerancia a errores y recuperación integradas, lo que permite a los investigadores de AA codificar como si estuvieran usando una sola máquina potente.

Motivación

Para poder entrenar y, luego, implementar modelos grandes, se necesitan cientos o miles de chips. Estos chips se distribuyen en numerosos racks y máquinas anfitrión. Un trabajo de entrenamiento es un programa síncrono a gran escala que requiere que todos estos chips y sus respectivos hosts trabajen en conjunto en cálculos de XLA que se hayan paralelizado (fragmentado). En el caso de los modelos de lenguaje grandes, que pueden necesitar más de decenas de miles de chips, este servicio debe ser capaz de abarcar varios Pods en una estructura de centro de datos, además de usar estructuras de interconexión entre chips (ICI) y de interconexión en el chip (OCI) dentro de un Pod.

Diseño

La ruta de aprendizaje de AA es el sistema que usamos para coordinar los cálculos distribuidos en hosts y chips TPU. Está diseñada para brindar escalabilidad y eficiencia en cientos de miles de aceleradores. Para el entrenamiento a gran escala, proporciona un solo cliente de Python para varios trabajos de Pod, integración de Megascale XLA, servicio de compilación y Python remoto. También admite el paralelismo entre porciones y la tolerancia a la interrupción, lo que permite la recuperación automática de las interrupciones de recursos.

La ruta de aprendizaje incorpora colectivos optimizados entre hosts que permiten que los grafos de procesamiento de XLA se extiendan más allá de un solo pod de TPU. Expande la compatibilidad de XLA con el paralelismo de datos, modelos y canalizaciones para trabajar en los límites de las porciones de TPU a través de la red del centro de datos (DCN) con la integración de un entorno de ejecución distribuido que administra la comunicación de la DCN con las primitivas de comunicación de XLA.

Principales fortalezas

La arquitectura de un solo controlador, integrada con JAX, es una abstracción clave. Permite a los investigadores explorar diversas estrategias de paralelismo y fragmentación para el entrenamiento y la implementación, mientras se escala a decenas de miles de chips con facilidad.

Desarrollo avanzado: Rendimiento, datos y eficiencia

Pallas: Escribe kernels personalizados de alto rendimiento en JAX

Si bien JAX prioriza el compilador, hay situaciones en las que es posible que desees tener un control detallado sobre el hardware para lograr el máximo rendimiento. Pallas es una extensión de JAX que permite escribir kernels personalizados para GPU y TPU. Su objetivo es proporcionar un control preciso sobre el código generado, combinado con la ergonomía de alto nivel del seguimiento de JAX y la API de jax.numpy.

Pallas expone un modelo de paralelismo basado en una cuadrícula en el que se inicia una función de kernel definida por el usuario en una cuadrícula multidimensional de grupos de trabajo paralelos. Permite la administración explícita de la jerarquía de memoria, ya que te permite definir cómo se segmentan los tensores y se transfieren entre la memoria más lenta y más grande (por ejemplo, HBM), y la memoria más rápida y más pequeña en el chip (por ejemplo, VMEM en TPU, memoria compartida en GPU), con mapas de índices para asociar ubicaciones de cuadrícula con bloques de datos específicos. Pallas puede reducir la misma definición del kernel para que se ejecute de manera eficiente en las TPU de Google y en varias GPU. Para ello, compila los kernels en una representación intermedia adecuada para la arquitectura de destino: mosaico para las TPU o tecnologías como Triton para las GPU. Con Pallas, puedes escribir kernels de alto rendimiento que especialicen bloques, como la atención, para lograr el mejor rendimiento del modelo en el hardware objetivo sin necesidad de depender de los kits de herramientas específicos del proveedor.

Tokamax: Una biblioteca seleccionada de kernels de vanguardia

Si Pallas es una herramienta para formar kernels, Tokamax es una biblioteca de kernels de aceleradores personalizados de vanguardia que admiten tanto TPU como GPU. Tokamax se basa en JAX y Pallas, y te permite usar toda la potencia de tu hardware. También proporciona herramientas para que compiles y ajustes de forma automática kernels personalizados.

Motivación

JAX, con sus raíces en XLA, es un framework basado en el compilador. Sin embargo, existe un conjunto reducido de casos en los que es posible que debas tomar el control directo del hardware para lograr el máximo rendimiento.4 Los kernels personalizados son fundamentales para obtener el mejor rendimiento de los recursos costosos de los aceleradores de AA, como las TPU y GPU. Si bien se emplean ampliamente para permitir la ejecución eficiente de operadores clave, como Attention, su implementación requiere una comprensión profunda tanto del modelo como de la arquitectura de hardware objetivo. Tokamax proporciona una fuente autorizada de kernels seleccionados, bien probados y de alto rendimiento, junto con una infraestructura compartida sólida para su desarrollo y mantenimiento, y administración del ciclo de vida. Esta biblioteca también puede actuar como una implementación de referencia para que la uses como base y la personalices según sea necesario. Esto te permite enfocarte en tus esfuerzos de modelado sin tener que preocuparte por la infraestructura.


4 Este es un paradigma bien establecido y tiene precedentes en el mundo de la CPU, en el que el código compilado constituye la mayor parte del programa y los desarrolladores recurren a funciones intrínsecas o al ensamble intercalado para optimizar las secciones fundamentales para el rendimiento.

Diseño

Para cualquier kernel determinado, Tokamax proporciona una API común que puede estar respaldada por varias implementaciones. Por ejemplo, los kernels de TPU se pueden implementar con la reducción estándar de XLA o de forma explícita con Pallas/Mosaico-TPU. Los kernels de GPU se pueden implementar con la reducción estándar de XLA, con Mosaico-GPU o con Triton. De forma predeterminada, la API de Tokamax elige la implementación más conocida para una configuración determinada, según los resultados almacenados en caché de las ejecuciones periódicas de ajuste automático y comparativas, aunque puedes elegir implementaciones específicas si es necesario. Con el tiempo, se pueden agregar nuevas implementaciones para aprovechar mejor las funciones específicas de las nuevas generaciones de hardware y obtener un rendimiento aún mejor.

Un componente clave de la biblioteca de Tokamax, más allá de los kernels en sí, es la infraestructura de asistencia que te permite escribir kernels personalizados. Por ejemplo, la infraestructura de ajuste automático te permite definir un conjunto de parámetros configurables (por ejemplo, tamaños de mosaico) en los que Tokamax puede realizar un análisis exhaustivo para determinar y almacenar en caché los mejores parámetros de configuración ajustados posibles. Las regresiones nocturnas te protegen de problemas inesperados de rendimiento y numéricos causados por cambios en la infraestructura subyacente del compilador o en otras dependencias.

Principales fortalezas
  • Experiencia fluida para desarrolladores: Una biblioteca unificada y seleccionada proporciona implementaciones conocidas y de alto rendimiento de kernels clave, con expresiones claras de las generaciones de hardware compatibles y el rendimiento esperado, tanto de forma programática como en la documentación. Esto minimiza la fragmentación y la deserción.
  • Flexibilidad y administración del ciclo de vida: Puedes elegir diferentes implementaciones, incluso cambiarlas con el tiempo si es necesario. Por ejemplo, si el compilador XLA mejora la compatibilidad con ciertas operaciones y ya no requiere kernels personalizados, existe una ruta de baja y migración.
  • Extensibilidad: Puedes implementar tus propios kernels y aprovechar la infraestructura compartida de gran compatibilidad, lo que te permite enfocarte en las capacidades y optimizaciones de valor agregado. Las implementaciones estándar claramente creadas sirven como punto de partida para que los usuarios aprendan y las extiendan.

Qwix: Cuantización integral y no intrusiva

Qwix es una biblioteca de cuantización integral para la pila de IA de JAX que admite tanto LLM como otros tipos de modelos en todas las etapas, incluido el entrenamiento (entrenamiento con reconocimiento de cuantización [QAT], técnica de cuantización [QT], adaptación de bajo rango cuantificada [QLoRA]) y la inferencia posterior al entrenamiento (PTQ), dirigida a XLA y los entornos de ejecución integrados en el dispositivo.

Motivación

Las bibliotecas de cuantización existentes, en particular en el ecosistema de PyTorch, suelen tener propósitos limitados (por ejemplo, solo PTQ o solo QLoRA). Este panorama fragmentado te obliga a cambiar de herramientas, lo que impide el uso coherente del código y la coincidencia numérica precisa entre el entrenamiento y la inferencia. Además, muchas soluciones requieren modificaciones sustanciales del modelo, lo que acopla estrechamente la lógica del modelo a la de cuantización.

Diseño

La filosofía de diseño de Qwix enfatiza una solución integral y, de manera fundamental, una integración de modelos no intrusiva. Se diseñó con una arquitectura jerárquica y extensible basada en APIs funcionales reutilizables.

Esta integración no intrusiva se logra a través de un mecanismo de interceptación diseñado de forma meticulosa que redirecciona las funciones de JAX a sus contrapartes cuantificadas. Esto te permite integrar tus modelos sin ninguna modificación, lo que desacopla por completo el código de cuantización de las definiciones del modelo.

En el siguiente ejemplo, se muestra cómo aplicar la cuantificación w4a4 (activación de 4 bits y peso de 4 bits) a las capas de MLP de un LLM y la cuantificación w8 (peso de 8 bits) al incorporador. Para cambiar la receta de cuantización, solo debes actualizar la lista de reglas.

fp_model = ModelWithoutQuantization(...)
rules = [
    qwix.QuantizationRule(
        module_path=r'embedder',
        weight_qtype='int8',
    ),
    qwix.QuantizationRule(
        module_path=r'layers_\d+/mlp',
        weight_qtype='int4',
        act_qtype='int4',
        tile_size=128,
        weight_calibration_method='rms,7',
    ),
]
quantized_model = qwix.quantize_model(fp_model, qwix.PtqProvider(rules))
Principales fortalezas
  • Solución integral: Qwix se puede aplicar de forma general en numerosas situaciones de cuantificación, lo que garantiza un uso coherente del código entre el entrenamiento y la inferencia.
  • Integración de modelos no intrusiva: Como muestra el ejemplo, puedes integrar modelos con una sola línea de código. Esto te permite usar hiperparámetros en muchos esquemas de cuantización para encontrar el mejor equilibrio entre calidad y rendimiento.
  • Federada con otras bibliotecas: Qwix se integra perfectamente con la pila de IA de JAX. Por ejemplo, Tokamax se adapta de forma automática para usar versiones cuantificadas de kernels, sin código adicional del usuario, cuando el modelo se cuantifica con Qwix.
  • Apta para la investigación: Las APIs fundamentales y la arquitectura extensible de Qwix permiten a los investigadores explorar nuevos algoritmos y facilitan las comparaciones directas con herramientas integradas de evaluación y comparativas.

Capa de aplicación: entrenamiento y alineación

Entrenamiento de modelos de base: MaxText y MaxDiffusion

MaxText y MaxDiffusion son los frameworks de entrenamiento de modelos insignia LLM y de difusión de Google, respectivamente. Estos repositorios contienen una selección de implementaciones altamente optimizadas de modelos populares de código abierto. Cumplen un doble propósito: funcionan como una base de código de entrenamiento de modelos lista para usar y como una referencia que los compiladores de modelos de base pueden usar para la compilación.

Motivación

En toda la industria, se observa un rápido crecimiento del interés en entrenar modelos de IA generativa. La popularidad de los modelos abiertos aceleró esta tendencia, ya que proporcionó arquitecturas probadas. El entrenamiento y la adaptación de estos modelos requieren un alto rendimiento, eficiencia, escalabilidad para una gran cantidad de chips y código claro y comprensible. MaxText y MaxDiffusion son soluciones integrales que se pueden usar en TPU o GPU, y están diseñadas para satisfacer estas necesidades.

Diseño

MaxText y MaxDiffusion son bases de código de modelos de base diseñadas teniendo en cuenta la legibilidad y el rendimiento. Están estructurados con componentes reutilizables y bien probados: definiciones de modelos que usan kernels personalizados (como Tokamax) para obtener el máximo rendimiento, un arnés de entrenamiento para la organización y la supervisión, y un potente sistema de configuración que te permite controlar detalles, como la fragmentación y la cuantificación (con Qwix), a través de una interfaz intuitiva. Se incorporan funciones avanzadas de confiabilidad, como la creación de puntos de control de varios niveles, para garantizar un rendimiento adecuado sostenido.

MaxText y MaxDiffusion usan las mejores bibliotecas de JAX de su clase, Qwix, Tunix, Orbax y Optax para ofrecer capacidades principales. Estas bibliotecas proporcionan una infraestructura sólida y escalable, lo que reduce la sobrecarga de desarrollo y te permite enfocarte en la tarea de modelado. Para la inferencia, se comparte el código del modelo para permitir una entrega eficiente y escalable.

Principales fortalezas
  • Rendimiento por diseño: Con una infraestructura de entrenamiento configurada para un alto "goodput" (capacidad de procesamiento útil) y con implementaciones de modelos optimizadas para un alto MFU (utilización de flops del modelo), MaxText y MaxDiffusion ofrecen un alto rendimiento a gran escala listo para usar.
  • Diseñados para el escalamiento: Estos frameworks aprovechan la potencia de la pila de IA de JAX (en especial, la ruta de aprendizaje) y te permiten escalar sin problemas de decenas a decenas de miles de chips.
  • Base sólida para los compiladores de modelos de base: Las implementaciones legibles y de alta calidad sirven como un punto de partida sólido para que los desarrolladores las usen como una solución integral o como una implementación de referencia para sus propias personalizaciones.

Alineación y entrenamiento posteriores: el framework de Tunix

Tunix ofrece algoritmos de aprendizaje por refuerzo (RL) de código abierto de vanguardia, junto con un framework de trabajo y una infraestructura sólidos, lo que proporciona una ruta optimizada para que los desarrolladores experimenten con técnicas de posentrenamiento de LLM, incluido el ajuste supervisado (SFT) y la alineación con JAX y TPU.

Motivación

El posentrenamiento es un paso fundamental para desbloquear el verdadero poder de los LLM. La etapa de aprendizaje por refuerzo (RL) es esencial para desarrollar capacidades de alineación y razonamiento. El desarrollo de código abierto en esta área se basó casi exclusivamente en PyTorch y las GPU, lo que dejó una brecha importante para las soluciones de JAX y TPU. Tunix (ajuste en JAX) es una biblioteca de alto rendimiento nativa de JAX diseñada para cubrir esta brecha.

Diseño

Diagrama de Tunix

Desde la perspectiva del framework, Tunix permite una configuración de vanguardia que separa claramente los algoritmos de RL de la infraestructura. Ofrece una API liviana, similar a la de un cliente, que oculta la complejidad de la infraestructura de RL, lo que te permite desarrollar algoritmos nuevos. Tunix proporciona soluciones listas para usar para algoritmos populares, como Proximal Policy Optimization (PPO), Direct Preference Optimization (DPO) y otros.

En cuanto a la infraestructura, Tunix se integra con la ruta de aprendizaje, lo que permite una arquitectura de un solo controlador que hace que el entrenamiento de RL con varios nodos sea accesible. En cuanto al entrenamiento, Tunix admite de forma nativa el entrenamiento eficiente en parámetros (por ejemplo, LoRA) y aprovecha la fragmentación de JAX y XLA (paralelización general y escalable para el grafo de procesamiento de AA [GSPMD]) para generar un grafo de procesamiento de alto rendimiento. Es compatible con modelos populares de código abierto, como Gemma y Llama, listos para usar.

Principales fortalezas
  • Simplicidad: Proporciona una API de alto nivel similar a la de un cliente que abstrae las complejidades de la infraestructura distribuida subyacente.
  • Eficiencia del desarrollador: Tunix acelera el ciclo de vida de I+D con algoritmos y "recetas" integrados, lo que te proporciona un modelo funcional y te permite iterar con rapidez.
  • Rendimiento y escalabilidad: Tunix habilita una infraestructura de entrenamiento altamente eficiente y escalable horizontalmente aprovechando rutas de aprendizaje como un solo controlador en el backend.

Capa de aplicación: inferencia y producción

Un desafío histórico para la adopción de JAX fue la ruta desde la investigación hasta la producción. La pila de IA de JAX ahora proporciona una historia de producción madura de dos frentes que ofrece compatibilidad con el ecosistema y rendimiento de JAX.

Inferencia de LLM de alto rendimiento: la solución de vLLM

vLLM-TPU es la pila de inferencia de alto rendimiento de Google diseñada para ejecutar de manera eficiente modelos de lenguaje grandes (LLM) de PyTorch y JAX en Cloud TPU. Esto se logra integrando de forma nativa el framework popular de código abierto vLLM con el ecosistema de JAX y TPU de Google.

Motivación

La industria está evolucionando rápidamente, con una demanda cada vez mayor de soluciones de inferencia fluidas, de alto rendimiento y fáciles de usar. Por lo general, los desarrolladores se enfrentan a desafíos importantes debido a herramientas complejas y poco coherentes, un rendimiento deficiente y una compatibilidad limitada de los modelos. La pila de vLLM aborda estos problemas, ya que proporciona una plataforma unificada, intuitiva y de alto rendimiento.

Diseño

Esta solución extiende el framework de vLLM en lugar de reinventarlo. vLLM-TPU es un motor de entrega de LLM de código abierto altamente optimizado conocido por su alta capacidad de procesamiento, que se logra con funciones clave como PagedAttention (que administra las caches de KV como memoria virtual para minimizar la fragmentación) y Continuous Batching (que agrega solicitudes de forma dinámica al lote para mejorar la utilización).

vLLM-TPU se basa en esta base y desarrolla componentes centrales para el control de solicitudes, la programación y la administración de la memoria. Presenta un backend basado en JAX que actúa como un puente, ya que traduce el grafo de procesamiento y las operaciones de memoria de vLLM en un código ejecutable en la TPU. Este backend controla las interacciones del dispositivo, la ejecución del modelo JAX y los detalles específicos de la administración de la caché de KV en el hardware de TPU. Incorpora optimizaciones específicas de la TPU, como mecanismos de atención eficientes (por ejemplo, aprovechando los kernels de Pallas de JAX para la atención paginada irregular) y la cuantización, todo ello adaptado a la arquitectura de la TPU.

Principales fortalezas
  • Costo de incorporación y baja cero para los usuarios: Los usuarios pueden adoptar esta solución sin inconvenientes significativos. Desde la perspectiva de la experiencia del usuario, el procesamiento de solicitudes de inferencia en TPU debería ser el mismo que en las GPU. Se comparte la CLI para iniciar el servidor, aceptar mensajes y devolver resultados.
  • Adoptar por completo el ecosistema: Este enfoque utiliza y contribuye a la interfaz y la experiencia del usuario del vLLM, lo que garantiza la compatibilidad y la facilidad de uso.
  • Fungibilidad entre TPU y GPU: La solución funciona de manera eficiente en TPU y GPU, lo que te brinda flexibilidad.
  • Rentable (mejor rendimiento/$): Optimiza el rendimiento para proporcionar la mejor relación entre rendimiento y costo para los modelos populares.

Entrega de JAX: serialización de Orbax y motor de entrega de Neptune

Para los modelos que no son LLM o para los usuarios que desean una canalización completamente nativa de JAX, la biblioteca de serialización Orbax y el sistema del motor de entrega de Neptune (NSE) proporcionan una solución de entrega de alto rendimiento de extremo a extremo.

Motivación

Históricamente, los modelos de JAX solían depender de una ruta indirecta hacia la producción, como estar envueltos en grafos de TensorFlow y, luego, implementarse con TensorFlow Serving. Este enfoque presentó ineficiencias y limitaciones significativas, lo que obligó a los desarrolladores a interactuar con un ecosistema independiente y ralentizó la iteración. Un sistema de entrega nativo de JAX exclusivo es fundamental para la sustentabilidad, la reducción de la complejidad y la optimización del rendimiento.

Diseño

Esta solución consta de dos componentes principales, como se ilustra en el siguiente diagrama.

Diagrama de entrega de JAX

  1. Biblioteca de serialización de Orbax: Proporciona APIs fáciles de usar para serializar modelos de JAX en un nuevo y sólido formato de serialización de Orbax. Este formato está optimizado para la implementación en producción. Representa directamente los cálculos del modelo de JAX con StableHLO, lo que permite que el grafo de procesamiento se represente de forma nativa. También aprovecha TensorStore para almacenar pesos, lo que permite una carga rápida de puntos de control para la entrega.
  2. Motor de entrega de Neptune (NSE): Es el motor de entrega flexible y de alto rendimiento que acompaña a Neptune (por lo general, se implementa como un objeto binario de C++) y que está diseñado para ejecutar de forma nativa modelos de JAX en formato Orbax. El NSE ofrece capacidades esenciales para la producción, como la carga rápida de modelos, la entrega simultánea de alta capacidad de procesamiento por lotes integrados, la compatibilidad con varias versiones de modelos y la entrega de uno o varios hosts (aprovechando PJRT y rutas de aprendizaje). Usa el motor de entrega de Neptune para lo siguiente:
    • Modelos que no son LLM: Es una solución de uso general ideal para cargas de trabajo como sistemas de recomendación, modelos de difusión y otros modelos de IA.
    • LLM pequeños y servicio "con un solo ejemplo": Se diseñó para modelos no autorregresivos o más pequeños que se publican de forma "unaria", en donde toda la salida se genera en una sola pasada sin necesidad de una administración de estado compleja, como una caché de KV.

En resumen, el motor de entrega de Neptune cubre la brecha en la entrega de la amplia variedad de modelos que no son de lenguaje autorregresivo grandes, y proporciona una solución nativa de TPU de alto rendimiento para el ecosistema de AA más amplio.

Principales fortalezas
  • Entrega nativa de JAX: La solución se compila de forma nativa para JAX, lo que elimina la sobrecarga entre frameworks en la serialización y la entrega de modelos. Esto garantiza una carga rápida del modelo y una ejecución optimizada en CPU, GPU y TPU.
  • Implementación de producción sin esfuerzo: Los modelos serializados proporcionan una ruta de implementación hermética que no se ve afectada por la desviación en las dependencias de Python y habilita las verificaciones de integridad del modelo en el entorno de ejecución. Esto proporciona una ruta intuitiva y sin interrupciones para la producción de modelos de JAX.
  • Experiencia del desarrollador mejorada: Quitando la necesidad de envolver el framework de forma engorrosa, esta solución reduce mucho las dependencias y la complejidad del sistema, lo que acelera la iteración para los desarrolladores de JAX.

Análisis y generación de perfiles en todo el sistema

XProf: Generación de perfiles de rendimiento profunda y con integración de hardware

XProf es una herramienta de generación de perfiles y análisis del rendimiento que proporciona una visibilidad detallada de varios aspectos de la ejecución de la carga de trabajo de AA, lo que te permite depurar y optimizar el rendimiento. Está profundamente integrada en los ecosistemas de JAX y TPU.

Motivación

Por un lado, las cargas de trabajo de AA son cada vez más complicadas. Por otro lado, hay una explosión de capacidades de hardware especializadas dirigidas a estas cargas de trabajo. Es fundamental que coincidan de manera eficaz para garantizar el máximo rendimiento y eficiencia, dados los enormes costos de la infraestructura del AA. Esto requiere una visibilidad profunda tanto de la carga de trabajo como del hardware, presentada de una manera que se pueda comprender con rapidez. XProf se destaca en este aspecto.

Diseño

XProf consta de dos componentes principales: la recopilación y el análisis.

  1. Recopilación: XProf captura información de varias fuentes, como anotaciones en tu código de JAX, modelos de costos para operaciones dentro del compilador XLA y funciones de generación de perfiles de hardware creadas específicamente dentro de la TPU. Esta recopilación se puede activar de forma programática o según demanda, y genera un artefacto de evento integral.
  2. Análisis: XProf, luego, procesa los datos recopilados y crea un conjunto de visualizaciones potentes a las que se accede con un navegador.
Principales fortalezas

El verdadero poder de XProf proviene de su profunda integración con la pila completa, lo que proporciona una amplitud y profundidad de análisis que son un beneficio tangible del ecosistema JAX/TPU codiseñado.

  • Diseñado en conjunto con la TPU: XProf aprovecha las funciones de hardware diseñadas específicamente para la recopilación de perfiles sin problemas, lo que permite una sobrecarga de recopilación de menos del 1%. Esto permite que la creación de perfiles sea una parte iterativa y ligera del desarrollo.
  • Amplitud y profundidad del análisis: XProf proporciona un análisis profundo en múltiples ejes. Entre sus herramientas, se incluyen las siguientes:
    • Visualizador de seguimiento: Es una vista del cronograma de la operación de ejecución en diferentes unidades de hardware (por ejemplo, TensorCores).
    • Perfil de operaciones de HLO: Desglosa el tiempo total invertido en diferentes categorías de operaciones.
    • Vista de recuerdos: Detalla las asignaciones de memoria creadas por diferentes operaciones durante el período analizado.
    • Análisis de Roofline: Te ayuda a identificar si las operaciones específicas están limitadas por la capacidad de procesamiento o la memoria, y qué tan lejos están de las capacidades máximas del hardware.
    • Visualizador de grafos: Proporciona una vista del grafo de HLO completo que ejecuta el hardware.

Una perspectiva comparativa: la pila de JAX/TPU como una opción atractiva

El panorama moderno del aprendizaje automático ofrece muchas cadenas de herramientas excelentes y consolidadas. La pila de IA de JAX presenta un conjunto único y atractivo de ventajas para los desarrolladores que se centran en el AA a gran escala y de alto rendimiento, que se derivan directamente de su diseño modular y su profundo codiseño de hardware.

Si bien muchos frameworks ofrecen una amplia variedad de funciones, la pila de IA de JAX proporciona diferenciadores específicos y potentes en áreas clave del ciclo de vida del desarrollo:

  • Una experiencia del desarrollador más simple y potente: El paradigma de transformación de gradiente encadenable de Optax permite estrategias de optimización más potentes y flexibles que se declaran una vez, en lugar de administrarse de forma imperativa en el bucle de entrenamiento. A nivel del sistema, la interfaz de controlador único más simple de las rutas de aprendizaje abstrae la complejidad del entrenamiento de porciones múltiples, lo que representa una simplificación importante para los investigadores.
  • Diseñado para una resiliencia a gran escala: La pila de JAX está diseñada para el entrenamiento a una escala extrema. Orbax proporciona funciones de "resiliencia de entrenamiento a gran escala", como la creación de puntos de control de emergencia y de varios niveles. Esto se complementa con Grain, que ofrece compatibilidad total con la reproducibilidad con aleatorizaciones globales determinísticas y cargadores de datos con puntos de control. La capacidad de crear puntos de control de forma atómica para el estado de la canalización de datos (Grain) con el estado del modelo (Orbax) es fundamental para garantizar la reproducibilidad en los trabajos de larga duración.
  • Un ecosistema completo de extremo a extremo: La pila proporciona una solución cohesiva de extremo a extremo. Los desarrolladores pueden usar MaxText como referencia del SOTA para el entrenamiento, Tunix para la alineación, y seguir una ruta de producción clara de doble vía con vLLM-TPU (para la compatibilidad con vLLM) y NSE (para el rendimiento de JAX).

Si bien muchas pilas son similares desde el punto de vista del software de alto nivel, el factor decisivo suele ser el rendimiento o el TCO, y es allí donde el diseño conjunto de JAX y TPU proporciona una ventaja distintiva. Este beneficio de rendimiento o TCO es el resultado directo de la integración vertical en el software y el hardware de TPU. La capacidad del compilador XLA para fusionar operaciones específicamente para la arquitectura de TPU o la del generador de perfiles XProf para usar hooks de hardware para la generación de perfiles con una sobrecarga inferior al 1% son beneficios tangibles de esta integración profunda.

Para las organizaciones que adoptan esta pila, la naturaleza completa de la pila de IA de JAX minimiza el costo de la migración. Para los clientes que emplean arquitecturas de modelos abiertos populares, el cambio de otros frameworks a MaxText suele ser una cuestión de establecer archivos de configuración. Además, la capacidad de la pila para transferir formatos de puntos de control populares, como safetensors, permite migrar los puntos de control existentes sin necesidad de volver a entrenar el modelo, lo que resulta costoso.

En la siguiente tabla, se proporciona una asignación de los componentes que ofrece la pila de IA de JAX y sus equivalentes en otros frameworks o bibliotecas.

Función JAX Alternativas/equivalentes en otros frameworks5
Compilador/entorno de ejecución XLA Inductor, ansioso
Entrenamiento con varios Pods Rutas de aprendizaje Estrategias de iluminación con Torch, Ray Train y Monarch (nuevo).
Framework principal JAX PyTorch
Formación de modelos Modelos de Flax y Max* torch.nn.*, TransformerEngine de NVidia, Transformers de Hugging Face
Optimizadores y pérdidas Optax torch.optim.*, torch.nn.*Loss
Cargadores de datos Grain Ray Data, cargadores de datos de Hugging Face
Punto de control Orbax Verificación de puntos de control distribuida de PyTorch y de NeMo
Cuantización Qwix TorchAO, bitsandbytes
Autorización de kernels y sus implementaciones conocidas Pallas/Tokamax Triton/Helion, Liger-kernel, TransformerEngine
Entrenamiento/ajuste posterior Tunix VERL, NeMoRL
Generación de perfiles XProf Generador de perfiles de PyTorch, sistemas NSight y NSight Compute
Entrenamiento de modelos de base MaxText y MaxDiffusion NeMo-Megatron, DeepSpeed y TorchTitan
Inferencia de LLM vLLM SGLang
Inferencia sin LLM NSE Triton Inference Server, RayServe

5 Algunos de los equivalentes aquí no siempre son comparaciones verdaderas porque otros frameworks trazan los límites de la API de manera diferente en comparación con JAX. La lista de equivalentes no es exhaustiva y aparecen bibliotecas nuevas con frecuencia.

Conclusión: Una plataforma duradera y lista para producción para el futuro de la IA

Los datos proporcionados en la tabla anterior ilustran una conclusión evidente: estas pilas tienen sus propias fortalezas y debilidades en una pequeña cantidad de áreas; sin embargo, en general, son muy similares desde el punto de vista del software. Ambas pilas proporcionan soluciones listas para usar para el posentrenamiento, la adaptación posterior al entrenamiento y la implementación de modelos de base.

La pila de IA de JAX ofrece una solución atractiva y sólida para entrenar y para implementar modelos de AA a cualquier escala. Aprovecha la profunda integración vertical en el software y el hardware de TPU para ofrecer un rendimiento líder en su clase y un costo total de propiedad.

Basarse en sistemas internos probados de forma rigurosa le permitió a la pila evolucionar para proporcionar confiabilidad y escalabilidad inherentes. Esto les permite a los usuarios desarrollar y, luego, implementar con confianza hasta los modelos más grandes. Su diseño modular y componible, basado en la filosofía de la pila de IA de JAX, otorga a los usuarios una libertad y un control sin precedentes, lo que les permite adaptar la pila a sus necesidades específicas sin las restricciones de un framework monolítico.

La pila de IA de JAX proporciona una base duradera para que los usuarios compilen y lleven rápidamente a producción la investigación de vanguardia. Todo esto lo se logra gracias a XLA y las rutas de aprendizaje, que proporcionan una base escalable y tolerante a errores, JAX, que proporciona una biblioteca numérica expresiva y de alto rendimiento, potentes bibliotecas de desarrollo principales, como Flax, Optax, Grain y Orbax, herramientas de rendimiento avanzadas, como Pallas, Tokamax y Qwix, y una capa de producción y aplicación sólida en MaxText, vLLM y NSE.