Criar IA de produção em Cloud TPUs com JAX
A pilha de IA JAX expande o núcleo numérico JAX com uma coleção de bibliotecas compostas apoiadas pela Google, transformando-o numa plataforma de código aberto robusta, ponto a ponto, para aprendizagem automática em escalas extremas. Como tal, a pilha de IA do JAX consiste num ecossistema abrangente e robusto que aborda todo o ciclo de vida da AA:
Base à escala industrial: a pilha de IA JAX foi arquitetada para uma escala massiva, tirando partido dos ML Pathways para orquestrar a preparação em dezenas de milhares de chips e do Orbax para a criação de pontos de verificação assíncronos resilientes e de elevado débito, o que permite a preparação de modelos de última geração de nível de produção.
Conjunto de ferramentas completo e pronto para produção: a pilha de IA JAX oferece um conjunto abrangente de bibliotecas para todo o processo de desenvolvimento: Flax para a criação flexível de modelos, Optax para estratégias de otimização compostas e Grain para os pipelines de dados determinísticos essenciais para execuções reproduzíveis em grande escala.
Desempenho especializado de pico: para alcançar a utilização máxima do hardware, a pilha de IA JAX oferece bibliotecas especializadas, incluindo Tokamax para kernels personalizados de vanguarda, Qwix para quantização não intrusiva que aumenta a velocidade de preparação e inferência, e XProf para criação de perfis de desempenho profundos e integrados no hardware.
Caminho completo para a produção: a pilha de IA JAX oferece uma transição perfeita da investigação à implementação. Isto inclui o MaxText como referência escalável para a preparação de modelos de base, o Tunix para a aprendizagem por reforço (AR) e o alinhamento de vanguarda, e uma solução de inferência unificada com a integração de TPU vLLM e o tempo de execução de serviço JAX.
A filosofia da pilha de IA do JAX é a de componentes fracamente acoplados, cada um dos quais faz uma coisa bem. Em vez de ser uma framework de ML monolítica, o JAX em si tem um âmbito restrito e foca-se em operações de matriz eficientes e transformações de programas. O ecossistema baseia-se nesta estrutura essencial para oferecer uma vasta gama de funcionalidades relacionadas com a preparação de modelos de ML e outros tipos de cargas de trabalho, como a computação científica.
Este sistema de componentes pouco acoplados permite-lhe selecionar e combinar bibliotecas da melhor forma para se adequar aos seus requisitos. Do ponto de vista da engenharia de software, esta arquitetura também permite atualizar a funcionalidade que seria tradicionalmente considerada componentes essenciais da framework (por exemplo, pipelines de dados e checkpointing) de forma iterativa sem o risco de desestabilizar a framework essencial ou ficar presa em ciclos de lançamento. Uma vez que a maioria das funcionalidades é implementada em bibliotecas em vez de alterações a uma estrutura monolítica, isto torna a biblioteca numérica principal mais duradoura e adaptável a mudanças futuras no panorama tecnológico.
As secções seguintes oferecem uma vista geral técnica da pilha de IA JAX, das respetivas principais funcionalidades, das decisões de design subjacentes e da forma como se combinam para criar uma plataforma duradoura para cargas de trabalho de ML modernas.
A pilha de IA JAX e outros componentes do ecossistema
| Componente | Função / descrição |
|---|---|
| Núcleo e componentes da pilha de IA JAX1 | |
| JAX | Cálculo de matrizes orientado por aceleradores e transformação de programas (JIT, grad, vmap, pmap). |
| Flax | Biblioteca de criação de redes neurais flexível para a criação e modificação intuitivas de modelos. |
| Optax | Uma biblioteca de transformações de processamento e otimização de gradientes compostas. |
| Orbax | Biblioteca de pontos de verificação distribuídos "any-scale" para resiliência de treino em grande escala. |
| Grão | Uma biblioteca de data pipelines de entrada escalável, determinística e com pontos de verificação. |
| JAX AI stack - Infrastructure | |
| XLA | Compilador de aprendizagem automática de código aberto para TPUs, CPUs e GPUs. |
| Pathways | Tempo de execução distribuído para orquestrar a computação em dezenas de milhares de chips. |
| Coleção de IA JAX - Adv. Programação | |
| Pallas | Uma extensão JAX para escrever kernels personalizados de baixo nível e alto desempenho implementados em Python. |
| Tokamax | Uma biblioteca organizada de kernels personalizados de alto desempenho e de última geração (por exemplo, Attention). |
| Qwix | Uma biblioteca abrangente e não intrusiva para a quantização (PTQ, QAT e QLoRA). |
| JAX AI stack – Aplicação | |
| MaxText / MaxDiffusion | Estruturas de referência emblemáticas e escaláveis para preparar modelos de base (por exemplo, LLM e Diffusion). |
| Tunix | Uma estrutura para o alinhamento e o pós-treino de vanguarda (ARFH e ODP). |
| vLLM | Uma solução de inferência de LLM de alto desempenho que usa a integração incorporada da framework vLLM. |
| XProf | Um perfilador profundo integrado no hardware para análise do desempenho ao nível do sistema. |
1Incluído no pacote Python.jax-ai-stack
Figura 1: a pilha de IA JAX e os componentes do ecossistema

O imperativo arquitetónico: desempenho além das estruturas
À medida que as arquiteturas de modelos convergem, por exemplo, em transformadores multimodais de mistura de especialistas (MoE), a procura do desempenho máximo está a levar à emergência de megakernels. Um megakernel é efetivamente a passagem direta completa (ou uma grande parte) de um modelo específico, codificado manualmente através de uma API de nível inferior, como o CUDA SDK em GPUs NVIDIA. Esta abordagem alcança a máxima utilização do hardware através da sobreposição agressiva de computação, memória e comunicação. O trabalho recente da comunidade de investigação demonstrou que esta abordagem pode gerar ganhos significativos de débito, mais de 22% em alguns casos, para a inferência em GPUs. Esta tendência não se limita à inferência. Os dados sugerem que alguns esforços de preparação em grande escala envolveram o controlo de hardware de baixo nível para alcançar ganhos de eficiência substanciais.
Se esta tendência se acelerar, todas as frameworks de nível superior, tal como existem atualmente, correm o risco de se tornarem menos relevantes, uma vez que o acesso de baixo nível ao hardware é o que, em última análise, importa para o desempenho em arquiteturas estáveis e maduras. Isto representa um desafio para todas as stacks de ML modernas: como fornecer controlo de hardware ao nível de especialista sem sacrificar a produtividade e a flexibilidade de uma estrutura de alto nível.
Para que as TPUs ofereçam um caminho claro para este nível de desempenho, o ecossistema tem de expor uma camada de API mais próxima do hardware, o que permite o desenvolvimento destes núcleos altamente especializados. A pilha JAX foi concebida para resolver este problema, oferecendo um continuum de abstração (consulte a Figura 2), desde as otimizações automatizadas de alto nível do compilador XLA ao controlo manual detalhado da biblioteca de criação de kernels Pallas.
Figura 2: o continuum de abstração do JAX

A coleção de IA JAX principal
A base da pilha de IA JAX consiste em cinco bibliotecas principais que fornecem a base para o desenvolvimento de modelos:
JAX: uma base para transformação de programas de alto desempenho e compósitos
O JAX é uma biblioteca Python para computação de matrizes orientada para aceleradores e transformação de programas, concebida para computação numérica de elevado desempenho e aprendizagem automática em grande escala. Com o seu modelo de programação funcional e API semelhante ao NumPy, o JAX oferece uma base sólida para bibliotecas de nível superior.
Com o seu design baseado no compilador, o JAX promove inerentemente a escalabilidade através da utilização do XLA (consulte a secção XLA) para uma análise, otimização e segmentação de hardware agressivas de todo o programa. A ênfase do JAX na programação funcional (por exemplo, funções puras) torna as transformações de programas essenciais mais tratáveis e, crucialmente, compostas.
Estas transformações essenciais podem ser combinadas para alcançar um elevado desempenho e escalabilidade das cargas de trabalho em função do tamanho do modelo, do tamanho do cluster e dos tipos de hardware:
- jit: compilação just-in-time de funções Python em executáveis XLA otimizados e fundidos.
- grad: diferenciação automática, compatível com o modo direto e inverso, bem como derivadas de ordem superior.
- vmap: vetorização automática, que permite o processamento em lote e o paralelismo de dados sem problemas, sem modificar a lógica da função.
- pmap / shard_map: paralelização automática em vários dispositivos (por exemplo, núcleos de TPU), que formam a base para a preparação distribuída.
A integração perfeita com o modelo GSPMD (SPMD de uso geral) do XLA permite que o JAX paralelize automaticamente os cálculos em grandes TPU Pods com alterações mínimas ao código. Na maioria dos casos, a escalabilidade só requer anotações de divisão em fragmentos de alto nível.
Flax: criação flexível de redes neurais
O Flax simplifica a criação, a depuração e a análise de redes neurais no JAX, oferecendo uma abordagem intuitiva e orientada para objetos à criação de modelos. Embora a API funcional do JAX seja poderosa, oferece uma abstração baseada em camadas mais familiar para os programadores habituados a frameworks como o PyTorch, sem qualquer penalização de desempenho.
Este design simplifica a modificação ou a combinação de componentes do modelo preparado.
As técnicas como LoRA e quantização requerem definições de modelos manipuláveis, que a API NNX do Flax fornece através de uma interface Pythonic. NNX encapsula o estado do modelo, reduzindo a carga cognitiva do utilizador e permitindo a travessia programática e a modificação da hierarquia do modelo.
Principais pontos fortes:
- API intuitiva orientada por objetos: simplifica a criação de modelos e permite exemplos de utilização avançados, como a substituição de submódulos e a inicialização parcial.
- Consistente com o JAX principal: o Flax oferece transformações elevadas totalmente compatíveis com o paradigma funcional do JAX, oferecendo o desempenho total do JAX com maior facilidade de utilização para programadores.
Optax: estratégias de otimização e processamento de gradientes compostas
O Optax é uma biblioteca de processamento e otimização de gradientes para o JAX. Foi concebida para oferecer aos criadores de modelos bases que podem ser recombinadas de formas personalizadas para formar modelos de aprendizagem profunda, entre outras aplicações. Baseia-se nas capacidades da biblioteca JAX principal para fornecer uma biblioteca de funções de perda e otimização de alto desempenho bem testada e técnicas associadas que podem ser usadas para preparar modelos de ML.
Motivação
O cálculo e a minimização das perdas estão no centro do que permite o
treino de modelos de ML. Com o respetivo suporte para diferenciação automática, a biblioteca JAX principal oferece as capacidades numéricas para formar modelos, mas não oferece implementações padrão de otimizadores populares (por exemplo, RMSProp ou Adam) nem perdas (por exemplo, CrossEntropy ou MSE). Embora possa implementar estas funções (e alguns programadores avançados optem por fazê-lo), um erro numa implementação do otimizador introduziria problemas de qualidade do modelo difíceis de diagnosticar. Em vez de o utilizador implementar estas partes críticas, a Optax fornece implementações destes algoritmos que são testadas quanto à correção e ao desempenho.
O campo da teoria da otimização situa-se claramente no domínio da investigação. No entanto, o seu papel central na preparação também a torna uma parte indispensável da preparação de modelos de ML de produção. Uma biblioteca que desempenhe esta função tem de ser suficientemente flexível para se adaptar a iterações de investigação rápidas e suficientemente robusta e com bom desempenho para ser fiável para a preparação de modelos de produção. Também deve fornecer implementações bem testadas de algoritmos de vanguarda que correspondam às equações padrão. A biblioteca Optax, através da sua arquitetura modular componível e ênfase no código legível correto, foi concebida para alcançar este objetivo.
Design
O Optax foi concebido para melhorar a velocidade da investigação e a transição da investigação para a produção, fornecendo implementações legíveis, bem testadas e eficientes de algoritmos essenciais. O Optax tem utilizações além do contexto da aprendizagem profunda. No entanto, neste contexto, pode ser visto como uma coleção de funções de perda, algoritmos de otimização e transformações de gradientes bem conhecidas implementadas de forma puramente funcional, em conformidade com a filosofia do JAX. A coleção de perdas conhecidas e otimizadores permite que os utilizadores comecem a usar a API com facilidade e confiança.
A abordagem modular adotada pela Optax permite encadear vários otimizadores juntamente com outras transformações comuns (por exemplo, restrição de gradiente) e envolvê-los usando técnicas comuns, como MultiStep ou Lookahead, para alcançar estratégias de otimização eficazes com algumas linhas de código. A interface flexível permite-lhe pesquisar novos algoritmos de otimização e usar técnicas de otimização de segunda ordem avançadas, como shampoo ou muon.
# Optax implementation of a RMSProp optimizer with a custom learning rate
# schedule, gradient clipping and gradient accumulation.
optimizer = optax.chain(
optax.clip_by_global_norm(GRADIENT_CLIP_VALUE),
optax.rmsprop(learning_rate=optax.cosine_decay_schedule(init_value=lr,decay_steps=decay)),
optax.apply_every(k=ACCUMULATION_STEPS)
)
# The same thing, in PyTorch
optimizer = optim.RMSprop(model_params, lr=LEARNING_RATE)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=TOTAL_STEPS)
for i, (inputs, targets) in enumerate(data_loader):
# ... Training loop body ...
if (i + 1) % ACCUMULATION_STEPS == 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIP_VALUE)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
O fragmento de código anterior mostra como configurar um otimizador com uma taxa de aprendizagem personalizada, restrição de gradientes e acumulação de gradientes.
Principais pontos fortes
- Biblioteca robusta: oferece uma biblioteca abrangente de perdas, otimizadores e algoritmos com foco na correção e legibilidade.
- Transformações encadeáveis modulares: esta API flexível permite-lhe criar estratégias de otimização complexas e eficazes de forma declarativa, sem modificar o ciclo de preparação.
- Funcional e escalável: as implementações puramente funcionais integram-se perfeitamente com os mecanismos de paralelização do JAX (por exemplo, pmap), o que lhe permite usar o mesmo código para escalar de um único anfitrião para grandes clusters.
Orbax / TensorStore: criação de pontos de verificação distribuídos em grande escala
O Orbax é uma biblioteca de pontos de verificação para o JAX concebida para qualquer escala, desde um único dispositivo ao treino distribuído em grande escala. O objetivo é unificar as implementações de pontos de verificação fragmentadas e oferecer funcionalidades de desempenho críticas, como pontos de verificação assíncronos e de vários níveis, a um público mais vasto. O Orbax permite a resiliência necessária para tarefas de preparação em grande escala e oferece um formato flexível para a publicação de pontos de verificação.
Ao contrário dos sistemas generalizados de ponto de verificação e restauro que criam instantâneos do estado de todo o sistema, os pontos de verificação de ML com o Orbax persistem seletivamente apenas as informações essenciais para retomar os pesos do modelo de preparação, o estado do otimizador e o estado do carregador de dados. Esta abordagem direcionada minimiza o tempo de inatividade do acelerador. O Orbax consegue isto sobrepondo as operações de I/O com a computação, uma funcionalidade crítica para grandes cargas de trabalho. O tempo em que os aceleradores estão inativos é reduzido à duração da transferência de dados do dispositivo para o anfitrião, o que pode ser ainda mais sobreposto com o passo de preparação seguinte, tornando a criação de pontos de verificação quase gratuita do ponto de vista do desempenho.
Essencialmente, o Orbax usa o TensorStore para uma leitura e escrita eficientes e paralelas de dados de matriz. A API Orbax abstrai esta complexidade, oferecendo uma interface fácil de usar para processar PyTrees, que são a representação padrão dos modelos no JAX.
Principais pontos fortes:
- Adoção generalizada: Com milhões de transferências mensais, o Orbax serve como um meio comum para partilhar artefactos de ML.
- Simplifica as complexidades: o Orbax abstrai as complexidades da criação de pontos de verificação distribuídos, incluindo a poupança assíncrona, a atomicidade e os detalhes do sistema de ficheiros.
- Flexível: embora ofereça APIs para exemplos de utilização comuns, o Orbax permite-lhe personalizar o seu fluxo de trabalho para processar requisitos especializados.
- Com bom desempenho e escalável: as funcionalidades como a criação de pontos de verificação assíncronos, um formato de armazenamento eficiente (OCDBT) e as estratégias de carregamento de dados inteligentes garantem que o Orbax é escalável para execuções de preparação que envolvem dezenas de milhares de nós.
Grain: pipelines de dados de entrada determinísticos e escaláveis
O Grain é uma biblioteca Python para ler e processar dados para preparar e avaliar modelos JAX. É flexível, rápido e determinístico, e suporta funcionalidades avançadas, como a criação de pontos de verificação, que são essenciais para formar com êxito grandes cargas de trabalho. É compatível com formatos de dados e back-ends de armazenamento populares, e também oferece uma API flexível para expandir a compatibilidade com formatos e back-ends específicos do utilizador que não são suportados nativamente. Embora o Grain tenha sido concebido principalmente para funcionar com o JAX, é independente da framework, não requer o JAX para ser executado e também pode ser usado com outras frameworks.
Motivação
Os pipelines de dados formam uma parte crítica da infraestrutura de preparação. Têm de ser flexíveis para que as transformações comuns possam ser expressas de forma eficiente e ter um desempenho suficientemente bom para manter os aceleradores ocupados em todos os momentos. Também têm de ser capazes de acomodar vários formatos de armazenamento e backends. Devido aos tempos de passos mais elevados, a preparação de modelos grandes em grande escala coloca requisitos adicionais no pipeline de dados, além dos que são exigidos pelas cargas de trabalho de preparação normais, principalmente focados no determinismo e na reprodutibilidade2. A biblioteca Grain foi concebida com uma arquitetura flexível que satisfaz estas necessidades.
2Na secção 5.1 do artigo do PaLM, os autores observaram picos de perda muito grandes, apesar de terem a restrição de gradientes ativada. A solução foi remover os lotes de dados ofensivos e reiniciar a preparação a partir de um ponto de verificação antes do pico de perda. Isto só é possível com uma configuração de preparação totalmente determinística e reproduzível.
Design
Ao nível mais elevado, existem duas formas de estruturar um pipeline de entrada: como um cluster separado de trabalhadores de dados ou através da colocação conjunta dos trabalhadores de dados nos anfitriões que acionam os aceleradores. O Grain escolhe a segunda opção por vários motivos.
Os aceleradores são combinados com anfitriões potentes que normalmente ficam inativos durante os passos de preparação, o que os torna uma escolha natural para executar o pipeline de dados de entrada. Esta implementação tem vantagens adicionais: simplifica a sua vista da divisão de dados, fornecendo uma vista consistente da divisão em todas as entradas e cálculos. Pode argumentar-se que colocar o trabalhador de dados no anfitrião do acelerador corre o risco de saturar a CPU do anfitrião. No entanto, isto não impede o descarregamento de transformações com utilização intensiva de computação para outro cluster através de RPCs3.
No que diz respeito à API, com uma implementação pura de Python que suporta vários processos e uma API flexível, o Grain permite-lhe implementar transformações de dados arbitrariamente complexas ao compor fases do pipeline com base em paradigmas de transformação bem compreendidos.
O Grain suporta formatos de dados de acesso aleatório eficientes, como ArrayRecord e Bagz, juntamente com outros formatos de dados populares, como Parquet e TFDS. O Grain inclui suporte para leitura de sistemas de ficheiros locais, bem como leitura do Cloud Storage por predefinição. Além de suportar formatos de armazenamento e backends populares, uma abstração limpa à camada de armazenamento permite-lhe adicionar suporte ou encapsular as suas origens de dados existentes para serem compatíveis com a biblioteca Grain.
3É assim que os pipelines de dados multimodais têm de funcionar. Por exemplo, os tokenizadores de imagens e áudio são modelos que são executados nos seus próprios clusters nos seus próprios aceleradores, e os pipelines de entrada fariam chamadas RPC para converter exemplos de dados em streams de tokens.
Principais pontos fortes
- Introdução de dados determinística: a colocação do trabalhador de dados com o acelerador e a sua associação a uma ordenação aleatória global estável e a iteradores com pontos de verificação permite que o estado do modelo e o estado do pipeline de dados sejam verificados em conjunto num instantâneo consistente através do Orbax, o que melhora o determinismo do processo de preparação.
- APIs flexíveis para ativar transformações de dados avançadas: uma API de transformações Python pura e flexível permite-lhe realizar transformações de dados extensivas no pipeline de processamento de entrada.
- Suporte extensível para vários formatos e backends: uma API de origens de dados extensível suporta formatos de armazenamento e backends populares, e permite-lhe adicionar suporte para novos formatos e backends.
- Interface de depuração avançada: as ferramentas de visualização do pipeline de dados e um modo de depuração permitem-lhe analisar, depurar e otimizar o desempenho dos seus pipelines de dados.
A coleção de IA JAX alargada
Além da base essencial, um ecossistema avançado de bibliotecas especializadas fornece a infraestrutura, as ferramentas avançadas e as soluções da camada de aplicação necessárias para o desenvolvimento de ML completo.
Infraestrutura fundamental: compiladores e tempos de execução
XLA: o motor independente do hardware e centrado no compilador
Motivação
A XLA ou a álgebra linear acelerada é o compilador específico do domínio da Google, que está bem integrado no JAX e suporta dispositivos de hardware TPU, CPU e GPU. O XLA foi concebido para ser um gerador de código independente do hardware que segmenta TPUs, GPUs e CPUs.
O design de compilação primeiro do compilador XLA é uma escolha arquitetónica fundamental que cria uma vantagem duradoura num panorama de investigação em rápida evolução. Em contrapartida, a abordagem predominante centrada no kernel noutros ecossistemas baseia-se em bibliotecas otimizadas manualmente para o desempenho. Embora seja altamente eficaz para arquiteturas de modelos estáveis e bem estabelecidas, cria um obstáculo à inovação. Quando a nova investigação introduz arquiteturas inovadoras, o ecossistema tem de aguardar que sejam escritos e otimizados novos kernels. No entanto, o nosso design centrado no compilador pode, muitas vezes, generalizar-se a novos padrões, oferecendo um caminho de alto desempenho para a investigação de ponta desde o primeiro dia.
Design
O XLA funciona através da compilação Just-In-Time (JIT) dos gráficos de computação que o JAX gera durante o respetivo processo de rastreio (por exemplo, quando uma função é decorada com @jax.jit).
Esta compilação segue um pipeline de várias fases:
- Gráfico de computação do JAX
- Otimizador de alto nível (HLO)
- Otimizador de baixo nível (LLO)
- Código de hardware
- De JAX Graph para HLO: o gráfico de computação JAX é convertido na representação HLO do XLA. Neste nível elevado, são aplicadas otimizações poderosas e independentes do hardware, como a fusão de operadores e a gestão eficiente da memória. O dialeto StableHLO serve como uma interface duradoura e com versões para esta fase.
- Do HLO para o LLO: após as otimizações de alto nível, os backends específicos do hardware assumem o controlo, reduzindo a representação do HLO para um LLO orientado para a máquina.
- Do LLO ao código de hardware: o LLO é finalmente compilado num código de máquina altamente eficiente. Para as TPUs, este código é agrupado como pacotes de palavras de instruções muito longas (VLIW) que são enviados diretamente para o hardware.
Para o escalamento, o design do XLA baseia-se no paralelismo. Emprega algoritmos para usar ao máximo as unidades de multiplicação de matrizes (MXUs) num chip. Entre os chips, a XLA usa SPMD (Single Program Multiple Data), uma técnica de paralelização baseada no compilador que usa um único programa em todos os dispositivos. Este modelo avançado é exposto através das APIs JAX, o que lhe permite gerir o paralelismo de dados, modelos ou pipelines com anotações de divisão de nível superior.
Para padrões de paralelismo mais complexos, também é possível usar vários programas com vários dados (MPMD), e bibliotecas como PartIR:MPMD permitem que os utilizadores do JAX também forneçam anotações MPMD.
Principais pontos fortes
- Compilação: a compilação just-in-time do gráfico de computação permite otimizações ao esquema de memória, à atribuição de buffers e à gestão de memória. As alternativas, como as metodologias baseadas no kernel, transferem esse encargo para o programador. Na maioria dos casos, o XLA pode alcançar um excelente desempenho sem comprometer a velocidade de desenvolvimento.
- Paralelismo: o XLA implementa várias formas de paralelismo com SPMD, e isto é exposto ao nível do JAX. Isto permite-lhe expressar estratégias de divisão, permitindo a experimentação e a escalabilidade de modelos em milhares de chips.
Pathways: um tempo de execução unificado para computação distribuída em grande escala
Pathways oferece abstrações para a preparação e a inferência distribuídas com tolerância a falhas e recuperação integradas, o que permite aos investigadores de ML programar como se estivessem a usar uma máquina única e potente.
Motivação
Para poder preparar e implementar modelos grandes, são necessários centenas a milhares de chips. Estes chips estão distribuídos por vários racks e máquinas anfitriãs. Uma tarefa de preparação é um programa síncrono de grande escala que requer todos estes chips e os respetivos anfitriões a trabalhar em conjunto em cálculos XLA que foram paralelizados (divididos). No caso dos modelos de linguagem (conteúdo extenso), que podem precisar de mais de dezenas de milhares de chips, este serviço tem de ser capaz de abranger vários pods numa estrutura de centro de dados, além de usar estruturas de interconexão entre chips (ICI) e interconexão no chip (OCI) num pod.
Design
O ML Pathways é o sistema que usamos para coordenar cálculos distribuídos em anfitriões e chips de TPU. Foi concebido para escalabilidade e eficiência em centenas de milhares de aceleradores. Para a preparação em grande escala, fornece um único cliente Python para várias tarefas de pods, integração do Megascale XLA, serviço de compilação e Python remoto. Também suporta o paralelismo entre fatias e a tolerância de preempção, o que permite a recuperação automática de preempções de recursos.
Os Pathways incorporam coletivos entre anfitriões otimizados que permitem que os gráficos de computação XLA se estendam para além de um único agrupamento de TPUs. Expande o suporte do XLA para o paralelismo de dados, modelos e pipelines para funcionar em limites de fatias de TPUs usando a rede do centro de dados (DCN) através da integração de um tempo de execução distribuído que gere a comunicação DCN com primitivas de comunicação XLA.
Principais pontos fortes
A arquitetura de controlador único, integrada com o JAX, é uma abstração fundamental. Permite aos investigadores explorar várias estratégias de divisão e paralelismo para a preparação e a implementação, enquanto dimensionam facilmente para dezenas de milhares de chips.
Desenvolvimento avançado: desempenho, dados e eficiência
Pallas: escrever kernels personalizados de alto desempenho no JAX
Embora o JAX seja compilador primeiro, existem situações em que pode querer um controlo detalhado sobre o hardware para alcançar o máximo desempenho. O Pallas é uma extensão do JAX que permite escrever kernels personalizados para GPUs e TPUs. Tem como objetivo
oferecer um controlo preciso sobre o código gerado, combinado com a ergonomia
de alto nível da rastreabilidade do JAX e da API jax.numpy.
O Pallas expõe um modelo de paralelismo baseado em grelhas onde uma função de kernel definida pelo utilizador é iniciada numa grelha multidimensional de grupos de trabalho paralelos. Permite a gestão explícita da hierarquia de memória, permitindo-lhe definir como os tensores são divididos em mosaicos e transferidos entre uma memória mais lenta e maior (por exemplo, HBM) e uma memória no chip mais rápida e menor (por exemplo, VMEM no TPU, memória partilhada na GPU), usando mapas de índices para associar localizações da grelha a blocos de dados específicos. O Pallas pode reduzir a mesma definição do kernel para executar de forma eficiente nas TPUs da Google e em várias GPUs, compilando kernels numa representação intermédia adequada para a arquitetura de destino: Mosaic para TPUs ou usando tecnologias como o Triton para GPUs. Com o Pallas, pode escrever kernels de alto desempenho que especializam blocos como a atenção para alcançar o melhor desempenho do modelo no hardware de destino sem ter de depender de kits de ferramentas específicos do fornecedor.
Tokamax: uma biblioteca organizada de kernels de última geração
Se o Pallas for uma ferramenta para criar kernels, o Tokamax é uma biblioteca de kernels de aceleradores personalizados de vanguarda que suportam TPUs e GPUs. O Tokamax é criado com base no JAX e no Pallas, e permite-lhe usar todo o potencial do seu hardware. Também oferece ferramentas para criar e ajustar automaticamente kernels personalizados.
Motivação
O JAX, com raízes no XLA, é uma framework de compilação em primeiro lugar. No entanto, existe um conjunto restrito de casos em que pode ter de assumir o controlo direto do hardware para alcançar o máximo desempenho4. Os kernels personalizados são essenciais para obter o melhor desempenho de recursos de aceleradores de AA dispendiosos, como TPUs e GPUs. Embora sejam amplamente usadas para permitir a execução com bom desempenho de operadores importantes, como a atenção, a respetiva implementação requer uma compreensão profunda do modelo e da arquitetura de hardware de destino. A Tokamax oferece uma fonte autorizada de kernels organizados, bem testados e de elevado desempenho, juntamente com uma infraestrutura partilhada robusta para o respetivo desenvolvimento, manutenção e gestão do ciclo de vida. Esta biblioteca também pode funcionar como uma implementação de referência para criar e personalizar conforme necessário. Isto permite-lhe concentrar-se nos seus esforços de modelagem sem ter de se preocupar com a infraestrutura.
4Este é um paradigma bem estabelecido e tem precedentes no mundo da CPU, onde o código compilado constitui a maior parte do programa com os programadores a recorrer a funções intrínsecas ou a assemblagem inline para otimizar secções críticas para o desempenho.
Design
Para qualquer kernel, o Tokamax fornece uma API comum que pode ser suportada por várias implementações. Por exemplo, os núcleos da TPU podem ser implementados através da redução padrão da XLA ou explicitamente com a Pallas/Mosaic-TPU. Os kernels da GPU podem ser implementados através da redução padrão da XLA, com a Mosaic-GPU ou a Triton. Por predefinição, a API Tokamax escolhe a implementação mais conhecida para uma determinada configuração, determinada pelos resultados em cache de execuções periódicas de testes de referência e ajuste automático, embora possa escolher implementações específicas, se necessário. As novas implementações podem ser adicionadas ao longo do tempo para explorar melhor funcionalidades específicas em novas gerações de hardware e, assim, melhorar ainda mais o desempenho.
Um componente fundamental da biblioteca Tokamax, além dos próprios núcleos, é a infraestrutura de apoio que lhe permite escrever núcleos personalizados. Por exemplo, a infraestrutura de ajuste automático permite-lhe definir um conjunto de parâmetros configuráveis (por exemplo, tamanhos de mosaicos) nos quais o Tokamax pode fazer uma análise exaustiva para determinar e colocar em cache as melhores definições ajustadas possíveis. As regressões noturnas protegem-no de problemas inesperados de desempenho e numéricos causados por alterações à infraestrutura do compilador subjacente ou a outras dependências.
Principais pontos fortes
- Experiência de programador integrada: uma biblioteca unificada e organizada oferece implementações de bom desempenho conhecidas de kernels importantes, com expressões claras das gerações de hardware suportadas e do desempenho esperado, tanto programaticamente como na documentação. Isto minimiza a fragmentação e a rotatividade.
- Flexibilidade e gestão do ciclo de vida: pode escolher implementações diferentes e até alterá-las ao longo do tempo, se for adequado. Por exemplo, se o compilador XLA melhorar o suporte para determinadas operações e já não precisar de núcleos personalizados, existe um caminho para a descontinuação e a migração.
- Extensibilidade: pode implementar os seus próprios kernels, enquanto tira partido de uma infraestrutura partilhada bem suportada, o que lhe permite focar-se nas capacidades e otimizações de valor acrescentado. As implementações padrão claramente criadas servem como ponto de partida para os utilizadores aprenderem e as expandirem.
Qwix: quantização não intrusiva e abrangente
A Qwix é uma biblioteca de quantização abrangente para a pilha de IA do JAX, que suporta LLMs e outros tipos de modelos em todas as fases, incluindo o treino (Quantization Aware Training [QAT], Quantization Technique [QT], Quantized Low-Rank Adaptation [QLoRA]) e a inferência Post Training Quantization (PTQ), direcionada para os tempos de execução XLA e no dispositivo.
Motivação
As bibliotecas de quantização existentes, particularmente no ecossistema do PyTorch, servem frequentemente propósitos limitados (por exemplo, apenas PTQ ou apenas QLoRA). Este panorama fragmentado obriga a alternar entre ferramentas, o que dificulta a utilização consistente de código e a correspondência numérica precisa entre a preparação e a inferência. Além disso, muitas soluções requerem modificações substanciais do modelo, o que acopla fortemente a lógica do modelo à lógica de quantização.
Design
A filosofia de design da Qwix enfatiza uma solução abrangente e, fundamentalmente, uma integração de modelos não intrusiva. Tem uma arquitetura com um design hierárquico e extensível, criado com base em APIs funcionais reutilizáveis.
Esta integração não intrusiva é alcançada através de um mecanismo de interceção meticulosamente concebido que redireciona as funções JAX para as respetivas contrapartes quantizadas. Isto permite-lhe integrar os seus modelos sem modificações, separando completamente o código de quantização das definições dos modelos.
O exemplo seguinte demonstra a aplicação da quantização w4a4 (uma ponderação de 4 bits, uma ativação de 4 bits) às camadas MLP de um MDG e a quantização w8 (uma ponderação de 8 bits) ao incorporador. Para alterar a receita de quantização, só tem de
atualizar a lista de regras.
fp_model = ModelWithoutQuantization(...)
rules = [
qwix.QuantizationRule(
module_path=r'embedder',
weight_qtype='int8',
),
qwix.QuantizationRule(
module_path=r'layers_\d+/mlp',
weight_qtype='int4',
act_qtype='int4',
tile_size=128,
weight_calibration_method='rms,7',
),
]
quantized_model = qwix.quantize_model(fp_model, qwix.PtqProvider(rules))
Principais pontos fortes
- Solução abrangente: o Qwix é amplamente aplicável em vários cenários de quantização, garantindo uma utilização consistente do código entre a preparação e a inferência.
- Integração de modelos não intrusiva: como mostra o exemplo, pode integrar modelos com uma única linha de código. Isto permite-lhe usar hiperparâmetros em vários esquemas de quantização para encontrar o melhor compromisso entre qualidade e desempenho.
- Federado com outras bibliotecas: o Qwix integra-se perfeitamente com a pilha de IA JAX. Por exemplo, o Tokamax adapta-se automaticamente para usar versões quantizadas de kernels, sem código de utilizador adicional, quando o modelo é quantizado com o Qwix.
- Adequado para investigação: as APIs fundamentais e a arquitetura extensível do Qwix permitem aos investigadores explorar novos algoritmos e facilitam as comparações simples com ferramentas de avaliação e testes de referência integradas.
A camada de aplicação: formação e alinhamento
Preparação de modelos de base: MaxText e MaxDiffusion
O MaxText e o MaxDiffusion são, respetivamente, as estruturas de preparação de modelos de difusão e GML emblemáticos da Google. Estes repositórios contêm uma seleção de implementações altamente otimizadas de modelos de pesos abertos populares. Têm uma dupla finalidade: funcionam como uma base de código de preparação de modelos pronta a usar e como uma referência que os criadores de modelos de base podem usar para desenvolver.
Motivação
Existe um rápido crescimento do interesse em toda a indústria na preparação de modelos de IA gen. A popularidade dos modelos abertos acelerou esta tendência, oferecendo arquiteturas comprovadas. A preparação e a adaptação destes modelos requerem um elevado desempenho, eficiência, escalabilidade para um grande número de chips e código claro e compreensível. O MaxText e o MaxDiffusion são soluções abrangentes que podem ser usadas em TPUs ou GPUs e foram concebidas para satisfazer estas necessidades.
Design
O MaxText e o MaxDiffusion são bases de código de modelos fundamentais concebidas tendo em conta a legibilidade e o desempenho. Estão estruturadas com componentes reutilizáveis e bem testados: definições de modelos que usam kernels personalizados (como o Tokamax) para um desempenho máximo, um arnês de preparação para orquestração e monitorização, e um sistema de configuração avançado que lhe permite controlar detalhes como a divisão e a quantização (através do Qwix) através de uma interface intuitiva. As funcionalidades de fiabilidade avançada, como a verificação de vários níveis, estão incorporadas para garantir um bom rendimento sustentado.
O MaxText e o MaxDiffusion usam as melhores bibliotecas JAX da classe Qwix, Tunix, Orbax e Optax para oferecer capacidades essenciais. Estas bibliotecas oferecem uma infraestrutura robusta e escalável, o que reduz os custos de desenvolvimento e permite que se concentre na tarefa de modelagem. Para a inferência, o código do modelo é partilhado para permitir uma publicação eficiente e escalável.
Principais pontos fortes
- Desempenho por design: com uma infraestrutura de preparação configurada para um "goodput" (débito útil) elevado e implementações de modelos otimizadas para uma MFU (utilização de flops do modelo) elevada, o MaxText e o MaxDiffusion oferecem um elevado desempenho em grande escala de imediato.
- Criado para a escala: tirando partido do poder da pilha de IA JAX (especialmente Pathways), estas estruturas permitem-lhe escalar facilmente de dezenas de chips para dezenas de milhares de chips.
- Base sólida para criadores de modelos de base: as implementações legíveis e de alta qualidade servem como um ponto de partida sólido para os programadores usarem como uma solução completa ou como uma implementação de referência para as suas próprias personalizações.
Após a formação e o alinhamento: a estrutura Tunix
A Tunix oferece algoritmos de aprendizagem por reforço (AR) de código aberto de vanguarda, juntamente com uma estrutura e uma infraestrutura robustas, o que proporciona um caminho simplificado para os programadores experimentarem técnicas de pós-preparação de MDIs, incluindo o ajuste fino supervisionado (AFS) e o alinhamento através do JAX e das TPUs.
Motivação
A pós-preparação é um passo fundamental para desbloquear o verdadeiro poder dos GMLs. A fase de aprendizagem por reforço (AR) é particularmente crucial para desenvolver capacidades de alinhamento e raciocínio. O desenvolvimento de código aberto nesta área baseou-se quase exclusivamente no PyTorch e nas GPUs, o que deixou uma lacuna fundamental para as soluções JAX e TPU. O Tunix (Tune-in-JAX) é uma biblioteca nativa do JAX de alto desempenho concebida para colmatar esta lacuna.
Design

Do ponto de vista da estrutura, o Tunix permite uma configuração de vanguarda que separa claramente os algoritmos de RL da infraestrutura. Oferece uma API simples, semelhante à de um cliente, que oculta a complexidade da infraestrutura de RL, permitindo-lhe desenvolver novos algoritmos. A Tunix oferece soluções prontas a usar para algoritmos populares, incluindo a otimização de políticas proximais (PPO), a otimização de preferências diretas (DPO) e outros.
No lado da infraestrutura, a Tunix tem integração com os Pathways, o que permite uma arquitetura de controlador único que torna o treino de RL com vários nós acessível. No que diz respeito à preparação, o Tunix suporta nativamente a preparação eficiente em termos de parâmetros (por exemplo, LoRA) e tira partido da divisão em fragmentos do JAX e da XLA (paralelização geral e escalável para o gráfico de computação de AA [GSPMD]) para gerar um gráfico de computação com bom desempenho. É compatível com modelos de código aberto populares, como o Gemma e o Llama, de origem.
Principais pontos fortes
- Simplicidade: oferece uma API de alto nível semelhante a um cliente que abstrai as complexidades da infraestrutura distribuída subjacente.
- Eficiência do programador: o Tunix acelera o ciclo de vida de I&D com algoritmos e "receitas" incorporados, o que lhe dá um modelo funcional e permite iterar rapidamente.
- Desempenho e escalabilidade: o Tunix permite uma infraestrutura de preparação altamente eficiente e escalável horizontalmente tirando partido dos caminhos como um único controlador no back-end.
A camada de aplicação: produção e inferência
Um desafio histórico para a adoção do JAX tem sido o caminho da investigação para a produção. A pilha de IA JAX oferece agora uma história de produção madura de duas vertentes que oferece compatibilidade com o ecossistema e desempenho do JAX.
Inferência de MDI/CE de alto desempenho: a solução vLLM
O vLLM-TPU é a pilha de inferência de alto desempenho da Google concebida para executar modelos de linguagem (conteúdo extenso) (MDL/CE) do PyTorch e JAX de forma eficiente nas TPUs na nuvem. Consegue isto integrando nativamente o popular framework vLLM de código aberto com o ecossistema JAX e TPU da Google.
Motivação
A indústria está a evoluir rapidamente, com uma procura crescente de soluções de inferência integradas, de alto desempenho e fáceis de usar. Os programadores enfrentam frequentemente desafios significativos devido a ferramentas complexas e inconsistentes, desempenho inferior ao esperado e compatibilidade limitada dos modelos. A pilha vLLM resolve estes problemas através de uma plataforma unificada, com bom desempenho e intuitiva.
Design
Esta solução expande a framework vLLM, em vez de a reinventar. O vLLM-TPU é um motor de fornecimento de GMLs de código aberto altamente otimizado conhecido pelo seu elevado débito, alcançado através de funcionalidades importantes, como o PagedAttention (que gere caches KV como memória virtual para minimizar a fragmentação) e o Continuous Batching (que adiciona dinamicamente pedidos ao lote para melhorar a utilização).
O vLLM-TPU baseia-se nesta base e desenvolve componentes essenciais para o processamento de pedidos, o agendamento e a gestão de memória. Apresenta um backend baseado em JAX que funciona como uma ponte, traduzindo o gráfico computacional e as operações de memória do vLLM em código executável por TPUs. Este back-end processa as interações com o dispositivo, a execução do modelo JAX e os detalhes da gestão da cache KV no hardware da TPU. Incorpora otimizações específicas da TPU, como mecanismos de atenção eficientes (por exemplo, tirar partido dos núcleos JAX Pallas para atenção paginada irregular) e quantização, tudo adaptado à arquitetura da TPU.
Principais pontos fortes
- Custo de integração/desativação zero para os utilizadores: os utilizadores podem adotar esta solução sem atrito significativo. Do ponto de vista da experiência do utilizador, o processamento de pedidos de inferência em TPUs deve ser igual ao processamento em GPUs. A CLI para iniciar o servidor, aceitar comandos e devolver resultados é partilhada.
- Abrace totalmente o ecossistema: esta abordagem usa e contribui para a interface e a experiência do utilizador do vLLM, garantindo a compatibilidade e a facilidade de utilização.
- Fungibilidade entre TPUs e GPUs: a solução funciona de forma eficiente em TPUs e GPUs, o que lhe dá flexibilidade.
- Rentável (melhor desempenho/custo): otimiza o desempenho para oferecer a melhor relação desempenho/custo para modelos populares.
Publicação do JAX: serialização do Orbax e motor de publicação do Neptune
Para modelos que não sejam LLMs ou para utilizadores que desejem um pipeline totalmente nativo do JAX, a biblioteca de serialização Orbax e o sistema do motor de publicação Neptune (NSE) oferecem uma solução de publicação de alto desempenho integral.
Motivação
Historicamente, os modelos JAX dependiam frequentemente de um caminho indireto para a produção, como serem incluídos em gráficos do TensorFlow e implementados através do TensorFlow Serving. Esta abordagem introduziu limitações e ineficiências significativas, o que obrigou os programadores a interagir com um ecossistema separado e a abrandar a iteração. Um sistema de fornecimento nativo do JAX dedicado é fundamental para a sustentabilidade, a redução da complexidade e o desempenho otimizado.
Design
Esta solução consiste em dois componentes principais, conforme ilustrado no diagrama seguinte.

- Biblioteca de serialização Orbax: fornece APIs fáceis de usar para serializar modelos JAX num novo formato de serialização Orbax robusto. Este formato está otimizado para a implementação em produção. Representa diretamente os cálculos do modelo JAX com a utilização do StableHLO, o que permite que o gráfico de cálculo seja representado nativamente. Também tira partido do TensorStore para armazenar ponderações, o que permite um carregamento rápido de pontos de verificação para a publicação.
- Neptune Serving Engine (NSE): este é o motor de publicação flexível de alto desempenho que acompanha o Neptune (normalmente implementado como um binário C++) concebido para executar nativamente modelos JAX no formato Orbax. O NSE oferece capacidades essenciais para a produção, como carregamento rápido de modelos, serviço simultâneo de elevado débito com processamento em lote integrado, suporte para várias versões de modelos e serviço de anfitrião único e múltiplo (tirando partido do PJRT e dos Pathways). Use o Neptune
Serving Engine para:
- Modelos não baseados em GMLs: é uma solução de uso geral ideal para cargas de trabalho como sistemas de recomendação, modelos de difusão e outros modelos de IA.
- Pequenos MDIs e publicação "única": foi concebido para modelos não autorregressivos ou modelos mais pequenos que são publicados de forma "unária", em que a saída completa é gerada numa única passagem sem necessidade de gestão de estado complexa, como uma cache KV.
Em resumo, o Neptune Serving Engine preenche a lacuna para publicar a grande variedade de modelos que não são modelos de linguagem autorregressivos grandes, oferecendo uma solução nativa de TPU de alto desempenho para o ecossistema de AA mais amplo.
Principais pontos fortes
- Publicação nativa do JAX: a solução é criada nativamente para o JAX, o que elimina a sobrecarga entre frameworks na serialização e publicação de modelos. Isto garante um carregamento rápido do modelo e uma execução otimizada em CPUs, GPUs e TPUs.
- Implementação de produção sem esforço: os modelos serializados oferecem um caminho de implementação hermético que não é afetado pela deriva nas dependências do Python e permite verificações de integridade do modelo em tempo de execução. Isto oferece um caminho simples e intuitivo para a produção de modelos JAX.
- Experiência de programador melhorada: ao eliminar a necessidade de uma união de frameworks complexa, esta solução reduz significativamente as dependências e a complexidade do sistema, acelerando a iteração para os programadores do JAX.
Análise e criação de perfis ao nível do sistema
XProf: criação de perfis de desempenho detalhada e integrada no hardware
O XProf é uma ferramenta de criação de perfis e análise de desempenho que oferece visibilidade detalhada em vários aspetos da execução da carga de trabalho de ML, permitindo-lhe depurar e otimizar o desempenho. Está profundamente integrado nos ecossistemas do JAX e da TPU.
Motivação
Por um lado, as cargas de trabalho de AA estão a tornar-se cada vez mais complicadas. Por outro lado, existe uma explosão de capacidades de hardware especializadas que visam estas cargas de trabalho. A correspondência eficaz dos dois para garantir o máximo desempenho e eficiência é fundamental, dados os enormes custos da infraestrutura de ML. Isto requer uma visibilidade profunda da carga de trabalho e do hardware, apresentada de forma rapidamente consumível. O XProf destaca-se nesta área.
Design
O XProf é composto por dois componentes principais: recolha e análise.
- Recolha: o XProf capta informações de várias origens: anotações no seu código JAX, modelos de custos para operações no compilador XLA e funcionalidades de criação de perfis de hardware criadas especificamente no TPU. Esta recolha pode ser acionada programaticamente ou a pedido, gerando um artefacto de evento abrangente.
- Análise: o XProf pós-processa os dados recolhidos e cria um conjunto de visualizações eficazes, acedidas com um navegador.
Principais pontos fortes
O verdadeiro poder do XProf reside na sua integração profunda com a pilha completa, o que oferece uma amplitude e uma profundidade de análise que são uma vantagem tangível do ecossistema JAX/TPU concebido em conjunto.
- Concebido em conjunto com a TPU: o XProf explora as funcionalidades de hardware especificamente concebidas para uma recolha de perfis integrada, o que permite uma sobrecarga de recolha de menos de 1%. Isto permite que a criação de perfis seja uma parte simples e iterativa do desenvolvimento.
- Amplitude e profundidade da análise: o XProf gera uma análise detalhada em vários eixos. As respetivas ferramentas incluem:
- Trace Viewer: uma vista de cronologia de operações da execução em diferentes unidades de hardware (por exemplo, TensorCores).
- Perfil de operações de HLO: divide o tempo total gasto em diferentes categorias de operações.
- Visualizador de memória: detalha as atribuições de memória por diferentes operações durante a janela com perfil.
- Análise de limite máximo: ajuda a identificar se operações específicas estão limitadas pela computação ou pela memória e a que distância estão das capacidades máximas do hardware.
- Visualizador de gráficos: oferece uma vista do gráfico HLO completo executado pelo hardware.
Uma perspetiva comparativa: a pilha JAX/TPU como uma escolha apelativa
O panorama moderno da aprendizagem automática oferece muitas cadeias de ferramentas excelentes e desenvolvidas. A pilha de IA JAX apresenta um conjunto único e apelativo de vantagens para os programadores focados em ML de grande escala e alto desempenho, decorrentes diretamente do seu design modular e da profunda conceção conjunta de hardware.
Embora muitas frameworks ofereçam uma grande variedade de funcionalidades, a JAX AI Stack oferece diferenciadores específicos e poderosos em áreas importantes do ciclo de vida de desenvolvimento:
- Uma experiência de programador mais simples e poderosa: o paradigma de transformação de gradiente encadeável do Optax permite estratégias de otimização mais poderosas e flexíveis que são declaradas uma vez, em vez de serem geridas imperativamente no ciclo de preparação. Ao nível do sistema, a interface de controlador único mais simples do Pathways abstrai a complexidade da formação de várias fatias, o que representa uma simplificação significativa para os investigadores.
- Concebido para resiliência à escala de herói: a pilha JAX foi concebida para o treino à escala extrema. O Orbax oferece funcionalidades de "resiliência de treino à escala de herói" , como pontos de verificação de emergência e de vários níveis. Isto é complementado pelo Grain, que oferece suporte total para a reprodutibilidade com misturas globais determinísticas e carregadores de dados com pontos de verificação. A capacidade de criar pontos de verificação atómicos do estado do pipeline de dados (Grain) com o estado do modelo (Orbax) é uma capacidade crítica para garantir a reprodutibilidade em trabalhos de longa duração.
- Um ecossistema completo e abrangente: a pilha oferece uma solução coesa e abrangente. Os programadores podem usar o MaxText como referência SOTA para a preparação, o Tunix para o alinhamento e seguir um caminho duplo claro para a produção com o vLLM-TPU (para compatibilidade com vLLM) e o NSE (para o desempenho do JAX).
Embora muitas pilhas sejam semelhantes do ponto de vista do software de alto nível, o fator decisivo resume-se frequentemente ao desempenho/CCT, que é onde o design conjunto do JAX e das UTPs oferece uma vantagem distinta. Esta vantagem de desempenho/CCT é um resultado direto da integração vertical no software e no hardware da TPU. A capacidade do compilador XLA de fundir operações especificamente para a arquitetura de TPU ou do criador de perfis XProf de usar hooks de hardware para a criação de perfis com uma sobrecarga inferior a 1% são vantagens tangíveis desta integração profunda.
Para as organizações que adotam esta pilha, a natureza totalmente funcional da pilha de IA do JAX minimiza o custo da migração. Para os clientes que usam arquiteturas de modelos abertos populares, a mudança de outras frameworks para o MaxText é frequentemente uma questão de configurar ficheiros de configuração. Além disso, a capacidade da pilha de carregar formatos de pontos de verificação populares, como safetensors, permite a migração dos pontos de verificação existentes sem necessidade de uma nova preparação dispendiosa.
A tabela seguinte fornece um mapeamento dos componentes fornecidos pela pilha de IA JAX e os respetivos equivalentes noutras frameworks ou bibliotecas.
| Função | JAX | Alternativas/equivalentes noutras estruturas5 |
| Compilador / tempo de execução | XLA | Inductor, eager |
| Formação multipod | Pathways | Estratégias de iluminação Torch, Ray Train e Monarch (novo). |
| Framework principal | JAX | PyTorch |
| Criação de modelos | Modelos Flax e Max* | torch.nn.*,
NVidia TransformerEngine, HuggingFace Transformers
|
| Otimizadores e perdas | Optax | torch.optim.*, torch.nn.*Loss |
| Carregadores de dados | Granulado | Ray Data, HuggingFace dataloaders |
| Criação de pontos de restauro | Orbax | Criação de pontos de verificação distribuídos do PyTorch, Criação de pontos de verificação do NeMo |
| Quantização | Qwix | TorchAO, bitsandbytes |
| Criação de kernels e implementações conhecidas | Pallas / Tokamax | Triton/Helion, Liger-kernel, TransformerEngine |
| Pós-formação / ajuste | Tunix | VERL e NeMoRL |
| Criação de perfis | XProf | PyTorch profiler, NSight systems, NSight Compute |
| Preparação de modelos de base | MaxText, MaxDiffusion | NeMo-Megatron, DeepSpeed, TorchTitan |
| Inferência de MDI/CE | vLLM | SGLang |
| Inferência não LLM | NSE | Triton Inference Server, RayServe |
5Alguns dos equivalentes aqui não são sempre comparações verdadeiras porque outras frameworks definem os limites da API de forma diferente em comparação com o JAX. A lista de equivalentes não é exaustiva e surgem novas bibliotecas com frequência.
Conclusão: uma plataforma duradoura e pronta para produção para o futuro da IA
Os dados fornecidos na tabela anterior ilustram uma conclusão evidente: estas pilhas têm os seus próprios pontos fortes e fracos num pequeno número de áreas, mas, no geral, são muito semelhantes do ponto de vista do software. Ambas as plataformas oferecem soluções prontas a usar para a pré-preparação, a adaptação pós-preparação e a implementação de modelos fundamentais.
A pilha de IA JAX oferece uma solução atraente e robusta para preparar e implementar modelos de ML a qualquer escala. Tira partido da integração vertical profunda no software e no hardware da TPU para oferecer um desempenho líder de mercado e um custo total de propriedade.
Com base em sistemas internos testados em combate, a plataforma evoluiu para oferecer fiabilidade e escalabilidade inerentes, o que permite aos utilizadores desenvolver e implementar com confiança até os maiores modelos. O seu design modular e componível, baseado na filosofia da pilha de IA JAX, concede aos utilizadores uma liberdade e um controlo sem paralelo, permitindo-lhes adaptar a pilha às suas necessidades específicas sem as restrições de uma framework monolítica.
Com a XLA e os Pathways a oferecerem uma base escalável e tolerante a falhas, a JAX a oferecer uma biblioteca numérica expressiva e de elevado desempenho, bibliotecas de desenvolvimento de núcleo avançadas, como Flax, Optax, Grain e Orbax, ferramentas de desempenho avançadas, como Pallas, Tokamax e Qwix, e uma camada de produção e aplicação robusta no MaxText, vLLM e NSE, a pilha de IA JAX oferece uma base duradoura para os utilizadores criarem e levarem rapidamente a produção a investigação de vanguarda.