Escalar um modelo em TPUs

Este documento oferece uma visão geral de como escalonar modelos de linguagem: como as TPUs funcionam e se comunicam entre si, como os LLMs são executados em hardware real e como paralelizar seus modelos durante o treinamento e a inferência para que eles sejam executados de maneira eficiente em grande escala. Fornecemos informações que ajudam você a avaliar o custo do treinamento de um LLM, a quantidade de memória necessária para disponibilizar o modelo e como fragmentar modelos de maneira eficaz em várias TPUs.

Embora grande parte do aprendizado profundo seja complexa, a otimização do desempenho dos modelos não precisa ser, mesmo em escala. Os princípios fundamentais se aplicam a todos os lugares — desde lidar com um único acelerador até dezenas de milhares — e a compreensão deles permite que você faça muitas coisas úteis:

  • Estimar a proximidade das partes do modelo ao ideal teórico.
  • Fazer escolhas informadas sobre diferentes esquemas de paralelismo em diferentes escalas (como você divide a computação em vários dispositivos).
  • Estimar o custo e o tempo necessários para treinar e executar modelos grandes do Transformer.
  • Criar algoritmos que aproveitem a arquitetura da TPU.
  • Criar arquiteturas de modelo orientadas por uma compreensão explícita do que limita o desempenho do algoritmo.

Pré-requisitos

Você precisa ter um conhecimento básico de LLMs e da arquitetura do Transformer, mas não necessariamente de como eles operam em escala. Você precisa entender os conceitos básicos do treinamento de LLMs e, idealmente, ter alguma familiaridade com o JAX. A leitura de informações básicas úteis para a arquitetura do Transformer inclui:

Depois de se familiarizar com esses pré-requisitos, você poderá estimar o melhor esquema de paralelismo para um modelo do Transformer em uma determinada plataforma de TPU. Você também poderá estimar quanto tempo o treinamento e a inferência devem levar.

Importância do escalonamento de modelos

Os LLMs e a maioria dos modelos pequenos atuais são executados tão perto dos limites de hardware que o desenvolvimento de modelos exige que você pense na eficiência em escala. Uma vitória de 20% em comparativos é irrelevante se ela tiver um custo de 20% para a eficiência do roofline. Arquiteturas de modelos promissoras falham rotineiramente porque não podem ser executadas de maneira eficiente em escala ou devido à falta de esforço de otimização para que isso aconteça.

O objetivo do escalonamento de modelos é aumentar o número de chips usados para treinamento ou inferência, ao mesmo tempo em que se alcança um aumento proporcional e linear na capacidade de processamento. Isso é conhecido como escalonamento forte. Embora a adição de chips (paralelismo) geralmente diminua o tempo de computação, ela também tem o custo de comunicação adicional entre os chips. Quando a comunicação leva mais tempo do que a computação, o modelo fica vinculado à comunicação e não pode ser escalonado bem. Entender o hardware bem o suficiente para prever onde esses gargalos vão surgir permite que você crie ou reconfigure seus modelos para evitá-los.

As seções a seguir oferecem uma visão geral de como escalonar o hardware da TPU e como a arquitetura do Transformer evoluiu. Essas informações são úteis para pesquisadores que criam novas arquiteturas e engenheiros que trabalham para que a geração atual de LLMs seja executada rapidamente.

Parte 1: conceitos

Esta parte explica a análise de roofline e os fatores que limitam a capacidade de um modelo de escalonamento (comunicação, computação e memória). Em seguida, descrevemos como as TPUs funcionam, tanto como chips individuais quanto, o que é de importância crítica, como um sistema interconectado com links entre chips de largura de banda e latência limitadas.

Parte 2: escalonamento de transformadores

É importante entender cada parte da arquitetura do Transformer: os tamanhos exatos de cada matriz, onde a normalização ocorre e quantos parâmetros e FLOPs estão em cada parte. Esta parte analisa cuidadosamente essa matemática do Transformer, mostrando como contar os parâmetros e FLOPs para treinamento e inferência. Isso informa quanta memória seu modelo vai usar, quanto tempo você vai gastar em computação ou comunicações e quando a atenção se tornará importante em relação aos blocos de feedforward.

Por fim, esta parte ajuda você a responder à pergunta fundamental: dado um modelo de um tamanho específico e fornecido com um determinado número de chips, como paralelizar o modelo para permanecer na condição de escalonamento forte. Para responder a essa pergunta, esta parte discute as quatro técnicas principais de paralelismo usadas para dividir modelos em vários chips: dados, tensor, pipeline e especialista. Ela também descreve outras técnicas para reduzir os requisitos de memória, como rematerialização, fragmentação de modelos com tecnologia ZeRO, descarregamento de host e acúmulo de gradiente.

  • Introdução às operações matemáticas do Transformer: esta seção explica a matemática para responder a perguntas sobre o número de FLOPs usados por um Transformer durante as passagens direta e reversa, cálculos para computar o número de parâmetros e o tamanho dos caches KV.
  • Paralelização do Transformer para treinamento: esta seção detalha o processo para maximizar a eficiência do treinamento coordenando o FSDP, a fragmentação do Megatron e o paralelismo de pipeline. Ela descreve como determinar a distribuição ideal para um tamanho de modelo e tamanho do lote específicos em um número fixo de chips para alcançar a capacidade de processamento máxima.
  • Escalonamento do Transformer para inferência: depois que um modelo é treinado, ele precisa ser veiculado. A inferência adiciona uma nova consideração, a latência, e muda o cenário de memória. Esta seção descreve como a veiculação desagregada funciona e como pensar em caches KV.
    • Disponibilização do Llama 3 em TPUs: esta subseção descreve como disponibilizar o Llama 3 em TPUs, quanto isso pode custar e as compensações de latência e capacidade de processamento.

Parte 3: implementação prática

Esta parte descreve como implementar os conceitos de escalonamento usando o JAX e como criar perfis e depurar seu código quando as coisas dão errado.

  • Como criar perfis de programas de TPU: os LLMs reais são complexos e difíceis de desenvolver, otimizar e depurar. Esta seção explica a pilha JAX + XLA e como usar o profiler JAX/TensorBoard para depurar e corrigir problemas reais.
  • Programação de TPUs no JAX: Esta seção descreve como usar as APIs JAX para paralelizar a computação.