Cómo escalar un modelo en TPUs

En este documento, se proporciona una descripción general de cómo escalar modelos de lenguaje: cómo funcionan las TPU y cómo se comunican entre sí, cómo se ejecutan los LLM en hardware real y cómo paralelizar los modelos durante el entrenamiento y la inferencia para que se ejecuten de forma eficiente a gran escala. Proporcionamos información que te ayuda a evaluar qué tan costoso será entrenar un LLM, cuánta memoria necesitas para publicarlo y cómo fragmentar modelos de manera eficaz en varias TPU.

Si bien gran parte del aprendizaje profundo es complejo, optimizar el rendimiento de tus modelos no tiene que serlo, incluso a gran escala. Los principios fundamentales se aplican en todas partes, desde un solo acelerador hasta decenas de miles, y comprenderlos te permite hacer muchas cosas útiles:

  • Estima qué tan cerca están las partes de tu modelo de su valor óptimo teórico.
  • Toma decisiones fundamentadas sobre diferentes esquemas de paralelismo en diferentes escalas (cómo divides el procesamiento en varios dispositivos).
  • Estima el costo y el tiempo necesarios para entrenar y ejecutar modelos Transformer grandes.
  • Diseñar algoritmos que aprovechen la arquitectura de las TPU
  • Diseñar arquitecturas de modelos basadas en una comprensión explícita de lo que limita el rendimiento del algoritmo

Requisitos previos

Debes tener un conocimiento básico de los LLM y la arquitectura de Transformer, pero no necesariamente de cómo operan a gran escala. Debes comprender los conceptos básicos del entrenamiento de LLM y, de manera ideal, tener cierta familiaridad con JAX. La lectura de antecedentes útiles para la arquitectura de Transformer incluye lo siguiente:

Después de familiarizarte con estos requisitos previos, deberías sentirte cómodo estimando el mejor esquema de paralelismo para un modelo Transformer en una plataforma de TPU determinada. También podrás estimar cuánto tiempo deberían llevar el entrenamiento y la inferencia.

Importancia del ajuste de escala del modelo

Actualmente, los LLM y la mayoría de los modelos pequeños se ejecutan tan cerca de los límites de hardware que el desarrollo de modelos requiere que pienses en la eficiencia a gran escala. Un aumento del 20% en las comparativas no es relevante si implica un costo del 20% en la eficiencia de la capacidad máxima de procesamiento. Las arquitecturas de modelos prometedoras suelen fallar porque no pueden ejecutarse de manera eficiente a gran escala o porque no se realizan los esfuerzos de optimización necesarios para que lo hagan.

El objetivo del escalamiento del modelo es poder aumentar la cantidad de chips utilizados para el entrenamiento o la inferencia y, al mismo tiempo, lograr un aumento proporcional y lineal en la capacidad de procesamiento. Esto se conoce como escalamiento fuerte. Si bien agregar chips adicionales (paralelismo) suele reducir el tiempo de procesamiento, también implica un costo de comunicación adicional entre los chips. Cuando la comunicación lleva más tiempo que el cálculo, el modelo se limita a la comunicación y no se puede escalar bien. Comprender el hardware lo suficiente como para anticipar dónde surgirán estos cuellos de botella te permite diseñar o reconfigurar tus modelos para evitarlos.

En las siguientes secciones, se proporciona una descripción general de cómo escalar el hardware de TPU y cómo evolucionó la arquitectura de Transformer. Esta información es útil tanto para los investigadores que diseñan arquitecturas nuevas como para los ingenieros que trabajan para que la generación actual de LLMs se ejecute rápidamente.

Parte 1: Conceptos

En esta parte, se explica el análisis de techo y los factores que limitan la capacidad de un modelo para escalar (comunicación, procesamiento y memoria). A continuación, describimos cómo funcionan las TPU, tanto como chips individuales como, lo que es de importancia crítica, como un sistema interconectado con vínculos entre chips de ancho de banda y latencia limitados.

Parte 2: Cómo escalar Transformers

Es importante comprender cada parte de la arquitectura de Transformer: los tamaños exactos de cada matriz, dónde se produce la normalización y cuántos parámetros y FLOPS hay en cada parte. En esta parte, se explica cuidadosamente la matemática de este transformador y se muestra cómo contar los parámetros y las FLOP para el entrenamiento y la inferencia. Esto te indica cuánta memoria usará tu modelo, cuánto tiempo dedicarás a la computación o las comunicaciones, y cuándo la atención se volverá importante en relación con los bloques de avance.

Por último, esta parte te ayuda a responder la pregunta fundamental: dado un modelo de un tamaño específico y una cantidad determinada de chips, ¿cómo paralelizar el modelo para mantener la condición de escalamiento fuerte? Para responder esta pregunta, en esta parte se analizan las cuatro técnicas principales de paralelismo que se usan para dividir los modelos en varios chips: datos, tensor, canalización y expertos. También describe otras técnicas para reducir los requisitos de memoria, como la rematerialización, la fragmentación del modelo potenciada por ZeRO, la descarga del host y la acumulación de gradientes.

  • Introducción a las operaciones matemáticas de Transformer: En esta sección, se explica el proceso matemático para responder preguntas sobre la cantidad de FLOPS que usa un Transformer durante los pases hacia adelante y hacia atrás, los cálculos para computar la cantidad de parámetros y el tamaño de los cachés de KV.
  • Paralelización de Transformer para el entrenamiento: En esta sección, se detalla el proceso para maximizar la eficiencia del entrenamiento coordinando el FSDP, la fragmentación de Megatron y el paralelismo de canalización. Se describe cómo determinar la distribución óptima para un tamaño de modelo y un tamaño de lote específicos en una cantidad fija de chips para lograr la capacidad de procesamiento máxima.
  • Ajuste de escala del Transformer para la inferencia: Después de entrenar un modelo, se debe entregar. La inferencia agrega una nueva consideración, la latencia, y cambia el panorama de la memoria. En esta sección, se describe cómo funciona la publicación desagregada y cómo pensar en las memorias caché de KV.
    • Cómo publicar Llama 3 en TPUs: En esta subsección, se describe cómo publicar Llama 3 en TPUs, cuánto podría costar y las compensaciones entre latencia y capacidad de procesamiento.

Parte 3: Implementación práctica

En esta parte, se describe cómo implementar los conceptos de escalamiento con JAX y cómo generar perfiles y depurar tu código cuando las cosas no funcionan.

  • Creación de perfiles de programas para TPU: Los LLM reales son complejos y difíciles de desarrollar, optimizar y depurar. En esta sección, se explica la pila de JAX + XLA y cómo usar el profiler de JAX/TensorBoard para depurar y corregir problemas reales.
  • Cómo programar TPUs en JAX: En esta sección, se describe cómo usar las APIs de JAX para paralelizar el procesamiento.