Mettre à l'échelle un modèle sur des TPU

Ce document présente la mise à l'échelle des modèles de langage : comment les TPU fonctionnent et communiquent entre eux, comment les LLM s'exécutent sur du matériel réel, et comment paralléliser vos modèles pendant l'entraînement et l'inférence afin qu'ils s'exécutent efficacement à grande échelle. Nous vous fournissons des informations qui vous aident à évaluer le coût de l'entraînement d'un LLM, la quantité de mémoire dont vous avez besoin pour diffuser le modèle et la façon de fragmenter efficacement les modèles sur plusieurs TPU.

Bien que le deep learning soit souvent complexe, l'optimisation des performances de vos modèles ne doit pas l'être, même à grande échelle. Les principes fondamentaux s'appliquent partout, qu'il s'agisse d'un seul accélérateur ou de dizaines de milliers. Les comprendre vous permet de faire de nombreuses choses utiles :

  • Estimez la proximité des parties de votre modèle par rapport à leur optimum théorique.
  • Faites des choix éclairés concernant les différents schémas de parallélisme à différentes échelles (comment répartir le calcul sur plusieurs appareils).
  • Estimez le coût et le temps nécessaires pour entraîner et exécuter de grands modèles Transformer.
  • Concevez des algorithmes qui tirent parti de l'architecture TPU.
  • Concevez des architectures de modèles basées sur une compréhension explicite des limites des performances de l'algorithme.

Prérequis

Vous devez avoir une compréhension de base des LLM et de l'architecture Transformer, mais pas nécessairement de leur fonctionnement à grande échelle. Vous devez comprendre les bases de l'entraînement des LLM et, idéalement, avoir une connaissance de base de JAX. Voici quelques lectures de base utiles pour l'architecture Transformer :

Une fois que vous vous serez familiarisé avec ces conditions préalables, vous devriez être en mesure d'estimer le meilleur schéma de parallélisme pour un modèle Transformer sur une plate-forme TPU donnée. Vous pourrez également estimer la durée de l'entraînement et de l'inférence.

Importance de la mise à l'échelle des modèles

Les LLM et la plupart des petits modèles actuels fonctionnent si près des limites matérielles que le développement de modèles nécessite de penser à l'efficacité à grande échelle. Une amélioration de 20 % par rapport aux benchmarks n'a aucune importance si elle s'accompagne d'une perte de 20 % en termes d'efficacité de la couverture. Les architectures de modèles prometteuses échouent régulièrement, soit parce qu'elles ne peuvent pas s'exécuter efficacement à grande échelle, soit parce qu'elles ne sont pas suffisamment optimisées pour le faire.

L'objectif du scaling de modèle est d'augmenter le nombre de puces utilisées pour l'entraînement ou l'inférence tout en obtenant une augmentation proportionnelle et linéaire du débit. C'est ce qu'on appelle la mise à l'échelle forte. Bien que l'ajout de puces (parallélisme) réduise généralement le temps de calcul, il entraîne également un coût de communication supplémentaire entre les puces. Lorsque la communication prend plus de temps que le calcul, le modèle devient lié à la communication et ne peut pas évoluer correctement. Comprendre suffisamment bien le matériel pour anticiper où ces goulots d'étranglement se produiront vous permet de concevoir ou de reconfigurer vos modèles pour les éviter.

Les sections suivantes présentent la mise à l'échelle du matériel TPU et l'évolution de l'architecture Transformer. Ces informations sont utiles aux chercheurs qui conçoivent de nouvelles architectures et aux ingénieurs qui s'efforcent d'accélérer l'exécution de la génération actuelle de LLM.

Partie 1 : Concepts

Cette partie explique l'analyse roofline et les facteurs qui limitent la capacité d'un modèle à évoluer (communication, calcul et mémoire). Nous décrirons ensuite le fonctionnement des TPU, à la fois en tant que puces individuelles et, surtout, en tant que système interconnecté avec des liaisons entre les puces dont la bande passante et la latence sont limitées.

Partie 2 : Mettre à l'échelle les Transformers

Il est important de comprendre chaque élément de l'architecture Transformer : la taille exacte de chaque matrice, où se produit la normalisation, et le nombre de paramètres et de FLOP dans chaque partie. Cette partie examine attentivement les mathématiques de Transformer, en montrant comment compter les paramètres et les FLOP pour l'entraînement et l'inférence. Cela vous indique la quantité de mémoire que votre modèle utilisera, le temps que vous passerez sur le calcul ou les communications, et le moment où l'attention deviendra importante par rapport aux blocs feed-forward.

Enfin, cette partie vous aide à répondre à la question fondamentale suivante : étant donné un modèle d'une taille spécifique et un certain nombre de puces, comment paralléliser le modèle pour rester dans la condition de scaling fort ? Pour répondre à cette question, cette partie aborde les quatre principales techniques de parallélisation utilisées pour répartir les modèles sur plusieurs puces : données, tenseur, pipeline et expert. Il décrit également d'autres techniques permettant de réduire les besoins en mémoire, telles que la rematérialisation, le partitionnement de modèle optimisé par ZeRO, le déchargement de l'hôte et l'accumulation de gradient.

  • Présentation des opérations mathématiques de Transformer : cette section explique les calculs mathématiques permettant de répondre aux questions sur le nombre de FLOPs utilisés par un Transformer lors des passes avant et arrière, les calculs permettant de déterminer le nombre de paramètres et la taille des caches KV.
  • Parallélisation des Transformers pour l'entraînement : cette section décrit en détail le processus permettant de maximiser l'efficacité de l'entraînement en coordonnant FSDP, la segmentation Megatron et le parallélisme de pipeline. Il explique comment déterminer la distribution optimale pour une taille de modèle et une taille de lot spécifiques sur un nombre fixe de puces afin d'atteindre un débit maximal.
    • Entraîner Llama 3 sur des TPU : cette sous-section explique comment entraîner Llama 3 sur des TPU, combien de temps cela peut prendre et combien cela peut coûter.
  • Mise à l'échelle des Transformers pour l'inférence : une fois qu'un modèle est entraîné, il doit être diffusé. L'inférence ajoute une nouvelle considération, la latence, et modifie le paysage de la mémoire. Cette section décrit le fonctionnement du serving désagrégé et la façon de concevoir les caches KV.
    • Diffuser Llama 3 sur des TPU : cette sous-section décrit comment diffuser Llama 3 sur des TPU, le coût potentiel, ainsi que les compromis entre latence et débit.

Partie 3 : Mise en œuvre pratique

Cette partie explique comment implémenter les concepts de scaling à l'aide de JAX, et comment profiler et déboguer votre code en cas de problème.

  • Profiler les programmes TPU : les LLM réels sont complexes et difficiles à développer, à optimiser et à déboguer. Cette section explique la pile JAX+XLA et comment utiliser le profileur JAX/TensorBoard pour déboguer et résoudre de vrais problèmes.
  • Programmer des TPU dans JAX : cette section explique comment utiliser les API JAX pour paralléliser le calcul.