Otimizações de desempenho do TPU7x (Ironwood)

Este guia descreve vários métodos para otimizar a performance com TPU7x (Ironwood) gerenciando com eficiência a movimentação de dados entre o sistema de memória de vários níveis. Isso inclui técnicas como treinamento de baixa precisão, fragmentação, otimização de comunicação, rematerialização de ativação, ajuste de memória virtual com escopo e kernels de acelerador personalizados.

Para otimizar a performance com a TPU7x, primeiro você precisa conhecer a arquitetura Ironwood, especificamente a hierarquia de memória e a topologia de interconexão. Para mais informações, consulte TPU7x (Ironwood).

Treinamento de baixa precisão com FP8

FP8 (ponto flutuante de 8 bits) é um formato de dados numéricos eficiente usado principalmente para acelerar o treinamento de modelo e a inferência. Ao representar números usando 8 bits, em vez dos formatos padrão de 16 bits (FP16 ou BF16) e 32 bits (FP32), as TPUs podem processar dados muito mais rápido e usar menos memória.

A TPU7x oferece suporte à aceleração de hardware integrada para tipos de dados FP8, com um desempenho teórico máximo de 4.614 TFLOPS por chip. Isso pode levar a tempos de treinamento de ponta a ponta significativamente mais rápidos. Para operações compatíveis, principalmente multiplicações de matrizes densas comuns em cargas de trabalho de IA, o uso de FP8 pode gerar melhorias de desempenho de 1,3x em relação ao treinamento BF16 padrão. Em comparação com o BF16, o FP8 dobra os FLOPs de pico e reduz pela metade o consumo de memória para pesos e ativações. O FP8 precisa ser um mecanismo de ajuste principal para cargas de trabalho vinculadas à computação e cenários limitados pela capacidade de memória ou largura de banda.

O uso do FP8 oferece os seguintes benefícios de performance:

  • Pressão reduzida no consumo de memória de alta largura de banda (HBM): um consumo de memória menor permite que modelos maiores ou com caches de chave-valor maiores durante a inferência se encaixem totalmente nos 192 GB de HBM. Isso evita o descarregamento caro para uma memória de host mais lenta.
  • Aumento do tamanho efetivo do lote: ao reduzir a memória necessária para ativações, o FP8 permite o uso de tamanhos de lote maiores. Isso melhora o paralelismo de dados e pode levar a maior capacidade de processamento e melhor utilização das unidades de computação.
  • Requisitos de largura de banda de memória mais baixos: mover metade da quantidade de dados para cada operação reduz a demanda no caminho de dados HBM para MXU. Em sistemas em que a movimentação de dados é um gargalo comum, isso ajuda a manter as MXUs saturadas de trabalho.

Usar FP8 com degradação zero ou limitada na performance exige a seleção cuidadosa de técnicas de quantização. Confira algumas práticas recomendadas para treinamento de FP8:

  • Granularidade de escalonamento: comece com o escalonamento por tensor como linha de base. Se houver problemas de qualidade ou desempenho, mude para o escalonamento por eixo. O escalonamento de subcanais pode ser desnecessário.
  • Modo de escalonamento: o escalonamento dinâmico, que calcula fatores de escalonamento em tempo de execução, é um bom padrão para manter a qualidade. Embora o escalonamento estático possa oferecer um aumento significativo no desempenho ao eliminar cálculos, ele exige um perfilamento cuidadoso para determinar os fatores de escalonamento corretos e pode não ser adequado para todos os casos de uso, especialmente quando as configurações do modelo mudam. Por outro lado, alguns modelos e configurações robustos podem corrigir a escala para o limite FP8 de pesos ou ativações, permitindo reduzir a sobrecarga de quantização e manter a precisão e melhorar o desempenho.
  • Formatos FP8 (E4M3 e E5M2): uma abordagem comum e eficaz é usar uma combinação de formatos FP8. Por exemplo, use E4M3 para pesos e ativações na transmissão direta para aproveitar a maior precisão do E4M3 e use E5M2 para gradientes na transmissão inversa para acomodar o intervalo dinâmico mais amplo dos gradientes.
  • Arredondamento: usar "arredondar para o número par mais próximo" (RNE, na sigla em inglês) em vez de arredondamento estocástico para gradientes pode manter a qualidade e oferecer melhor desempenho e reprodutibilidade.
  • Como ativar o FP8 no MaxText: o MaxText é compatível com o treinamento FP8 pela biblioteca de quantização QWIX. Para ativar a quantização, defina a seguinte flag na sua configuração: use_qwix_quantization=true.

Fragmentação e paralelismo

O sharding é o processo de dividir um modelo grande ou os dados de treinamento dele em partes menores e distribuí-las em vários chips ou núcleos de TPU. Escolher a estratégia de fragmentação certa é importante para alcançar alta performance na TPU7x.

Uma abordagem simples que apenas maximiza o grau de paralelismo geralmente resulta em desempenho ruim por se tornar vinculada à comunicação. A melhor abordagem geralmente é selecionar a estratégia de fragmentação mais simples que atenda às restrições de memória, já que isso minimiza a sobrecarga de comunicação e permite que as unidades de computação sejam usadas de maneira eficiente.

Antes de selecionar uma estratégia de fragmentação, a primeira etapa de qualquer ajuste de performance deve ser uma análise de intensidade aritmética. Essa análise determina se uma determinada computação é limitada por computação, largura de banda de memória ou largura de banda de interconexão. Ela é calculada como a proporção de operações de usar pontos flutuantes para os bytes de dados que precisam ser movidos.

Uma alta intensidade aritmética indica uma carga de trabalho vinculada à computação. Uma baixa intensidade aritmética sugere uma carga de trabalho limitada pela memória ou pela comunicação, em que o desempenho é limitado pela velocidade com que os dados podem ser movidos da HBM ou pela rede ICI. Essa análise informa o tamanho ideal do lote e a estratégia de fragmentação. Por exemplo, uma carga de trabalho vinculada à comunicação não se beneficia de uma estratégia de fragmentação que introduz ainda mais comunicação, como o paralelismo de tensores de alto grau.

Framework de decisão da estratégia de fragmentação

O MaxText oferece várias estratégias de fragmentação. A escolha ideal depende da arquitetura do modelo, do comprimento da sequência e da necessidade de equilibrar a carga computacional com a sobrecarga de comunicação.

  • Paralelismo de dados totalmente fragmentados (FSDP): essa é a estratégia padrão preferida para paralelismo de dados. O FSDP fragmenta os pesos do modelo, os gradientes e os estados do otimizador entre os dispositivos paralelos de dados. Durante a computação, cada dispositivo realiza uma operação All-Gather para recuperar os pesos completos necessários para o microlote local. O FSDP é altamente eficaz desde que o tamanho do lote por dispositivo seja grande o suficiente para ocultar a latência dessa comunicação All-Gather. Para modelos de combinação de especialistas (MoE, na sigla em inglês), o cálculo da intensidade aritmética precisa considerar a escassez.
  • Paralelismo de tensor (TP): o TP fragmenta tensores individuais em vários dispositivos. Normalmente, os tensores são matrizes de peso em perceptrons multicamadas (MLP) e blocos de atenção. A alta intensidade aritmética do hardware (11, 5 mil) impõe uma exigência muito alta nas dimensões do modelo para tornar o TP viável em relação ao ICI.Tentar usar o TP pode resultar em um sistema limitado pela comunicação.
  • Paralelismo de especialistas (EP): essa é a estratégia padrão e necessária para treinar modelos MoE. O EP fragmenta as camadas "especialistas" em um conjunto de dispositivos, e um coletivo de comunicação de todos para todos é usado para rotear tokens para o dispositivo especialista designado. A EP pode ser eficiente se a dimensão MLP do modelo for grande o suficiente para se aproximar do roofline.
  • Paralelismo de contexto (CP): o CP é uma estratégia especializada essencial para treinar modelos com sequências muito longas. A principal função dele é gerenciar o consumo de memória das ativações, que cresce quadraticamente com o comprimento da sequência e pode exceder a capacidade da HBM. A CP fragmenta a dimensão de sequência dos tensores de ativação, o que permite o uso de um tamanho de lote fracionário por dispositivo. Como o CP introduz mais comunicação do que o FSDP, a regra geral é usar o grau mínimo de CP necessário para satisfazer as restrições de memória e garantir que o fragmento do eixo de lote permaneça um número inteiro.

A tabela a seguir associa tipos de carga de trabalho comuns à estratégia de fragmentação ideal:

Tipo de carga de trabalho Fragmentação principal recomendada Fragmentação secundária Principais gargalos Justificativa
Modelo denso: sequência curta FSDP N/A Rematerialização, FF Matmuls O FSDP oferece o melhor equilíbrio. Com sequências curtas, a memória de ativação pode não ser um problema. A chave é um lote global grande o suficiente para ocultar o All-Gather de peso do FSDP. À medida que o tamanho do lote aumenta, o tamanho da ativação também aumenta, e uma política de rematerialização adequada é necessária para garantir que essa configuração não fique sem memória.
Modelo denso: sequência longa FSDP CP Atenção rápida, memória de ativação A memória de ativação se torna a restrição principal. O CP é necessário para ativar tamanhos de lote fracionários por dispositivo e evitar problemas de falta de memória (OOMs). A atenção rápida é a principal fonte de computação e tempo perdido.
Modelo MoE: sequência curta FSDP + EP N/A De todos para todos (roteamento especializado), rematerialização Os modelos MoE exigem EP para fragmentar os especialistas. A comunicação de todos para todos para o roteamento de tokens é um grande gargalo que precisa ser sobreposto. A rematerialização também é uma fonte significativa de desperdício.
Modelo MoE: escala muito grande FSDP + EP + PP Paralelismo de modelos (MP) Todos os gargalos mencionados anteriormente, além de bolhas de pipeline Para modelos que excedem a memória de um único pod, o PP é necessário para fragmentar camadas em pods. Isso introduz a comunicação DCN e os overheads de bolha do pipeline. Essa é uma configuração altamente complexa que exige ajuste cuidadoso.

Otimização da comunicação

O principal mecanismo para sobreposição de comunicação e computação na TPU7x é chamado de SparseCore Collective Offloading. A arquitetura do Ironwood inclui unidades SparseCore dedicadas, que atuam como linhas de execução de controle independentes capazes de gerenciar a movimentação de dados na malha ICI. Isso permite que operações de comunicação coletiva (como All-Gather ou Reduce-Scatter) sejam executadas em paralelo com as principais computações que acontecem nos TensorCores. Esse é o método recomendado para coletivos assíncronos em TPU7x. Use as flags recomendadas para ativar o descarregamento dos coletivos mais comuns.

Rematerialização de ativação

A rematerialização de ativação, também conhecida como checkpointing de gradiente, é uma técnica fundamental para reduzir a pegada de HBM de um modelo. Em vez de armazenar todas as ativações intermediárias da passagem direta em HBM para serem usadas durante a passagem reversa, ele salva apenas algumas ativações principais (pontos de verificação) e recalcula as outras sob demanda durante a passagem reversa. Isso economiza uma quantidade significativa de memória ao custo de maior computação (aproximadamente 25 a 30% de FLOPs adicionais para um bloco de transformador padrão).

A decisão de como aplicar a rematerialização de forma agressiva é um parâmetro de ajuste crítico que depende inteiramente do gargalo principal, que geralmente varia com o comprimento da sequência.

Para cargas de trabalho de sequência longa (como 128k): nesses casos, o tamanho dos tensores de ativação é o consumidor dominante de HBM. A carga de trabalho geralmente é limitada pela memória. Portanto, aplicar uma política de rematerialização agressiva é muito benéfico. A economia de memória permite que o treinamento prossiga sem erros de falta de memória e também permite tamanhos de lote maiores. Além disso, a sobrecarga computacional de recálculo é uma troca válida.

Para cargas de trabalho de sequência curta (como 8k): nesses casos, a memória de ativação é muito menos preocupante, e a carga de trabalho tem mais probabilidade de ser limitada por computação. A sobrecarga computacional da rematerialização pode ser a maior fonte de ineficiência.

Ajustar políticas de rematerialização no MaxText

O MaxText oferece controle granular sobre a rematerialização por um conjunto de políticas predefinidas e personalizadas, configuradas usando a flag remat_policy.

Políticas predefinidas

O MaxText oferece as seguintes políticas integradas:

  • full: a política mais agressiva, que materializa quase tudo. Isso minimiza o uso de HBM, mas maximiza a sobrecarga de recálculo. Ideal para cenários de sequência longa com restrições extremas de memória.
  • minimal: a política menos agressiva, que armazena a maioria das ativações. Isso maximiza o uso da HBM, mas minimiza o recálculo. Ideal para cargas de trabalho de sequência curta e vinculadas à computação em que a memória não é um problema.
  • Políticas intermediárias: opções como save_dot_with_context_except_mlp, save_qkv_proj e save_out_proj oferecem várias compensações ao fazer o checkpoint seletivo das saídas de operações de produto escalar caras e materializar novamente operações mais baratas elemento a elemento.

Políticas personalizadas

Para ter mais controle, defina remat_policy como custom. Isso permite especificar o comportamento de camadas individuais no módulo de decodificação do modelo. Cada camada pode receber um dos três comportamentos:

  • device: a ativação é armazenada na HBM no dispositivo de TPU.
  • remat: a ativação é descartada e será rematerializada durante a transmissão de volta.
  • offload: a ativação é movida da HBM para a memória do host da CPU, liberando a HBM ao custo da latência de transferência do PCIe.

Ajuste do VMEM com escopo

O desempenho do kernel, como a atenção rápida, depende dos tamanhos de bloco selecionados no kernel, cujo tamanho é limitado pela memória de vetor disponível (VMEM). Cada um dos dois TensorCores em um chip TPU7x tem 64 MiB de memória vetorial (VMEM). Essa capacidade de VMEM pode ser dividida entre o escopo atual (VMEM no escopo) e o pré-carregamento de peso futuro. Aumentar a VMEM no escopo permite aumentar os tamanhos dos blocos no kernel, reduzindo potencialmente as interrupções de memória e aumentando o desempenho dos kernels. É possível alterar o tamanho da VMEM no escopo definindo xla_tpu_scoped_vmem_limit_kib (em LIBTPU_INIT_ARGS), que pode ser usado para analisar o desempenho do kernel e os limites de desempenho de ponta a ponta. A otimização do tamanho do VMEM no escopo pode afetar indiretamente o desempenho do kernel Pallas personalizado, já que o aumento do VMEM no escopo desbloqueia um espaço de pesquisa de hiperparâmetros maior para tamanhos de bloco no kernel.

Kernels Tokamax

O Tokamax, uma biblioteca de kernels JAX de alta performance com muitos kernels de TPU altamente otimizados, resolve vários gargalos comuns específicos de hardware:

  • Atenção de splash: é usada como a principal implementação de atenção para eliminar o gargalo de HBM da atenção padrão e usa a implementação de atenção mais eficiente em TPUs.
  • Multiplicação de matrizes agrupadas (GMM) do Megablox: para cargas de trabalho de MoE, o Megablox processa com eficiência as multiplicações de matrizes agrupadas calculando a representação de ativações irregulares. Ele mapeia com eficiência a dimensão irregular, computando multiplicações de matrizes entre grupos irregulares de linhas no lado esquerdo e a matriz de especialistas correspondente, evitando a necessidade de adicionar lotes a um tamanho fixo.
  • Ajuste empírico com tune-jax: a biblioteca tune-jax tem utilitários para realizar pesquisas empíricas de tamanhos de bloco ideais. Os tamanhos de kernel padrão geralmente não são ideais. O ajuste permite escolher tamanhos de bloco de VMEM compatíveis com hardware para maximizar a utilização do hardware.
  • Estimativa de logits máximos: o kernel de atenção do Tokamax Splash pode ser ainda mais otimizado definindo um valor para max_logit_const. Se definido, ele substitui o cálculo de redução do logit máximo durante a operação softmax de atenção (softmax(Q * KT)), reduzindo parte da sobrecarga computacional e de sincronização. No MaxText, ele é implementado pela configuração use_max_logits_estimate, que pode ser definida como None (desativada) ou um pontuação flutuante. Verifique se o intervalo de logits do seu modelo específico continua compatível com a estimativa para evitar estouro numérico. O teste de convergência é recomendado se esse valor for definido.