Optimisations des performances du TPU7x (Ironwood)
Ce guide décrit plusieurs méthodes permettant d'optimiser les performances avec TPU v4 (Ironwood) en gérant efficacement le transfert de données entre son système de mémoire à plusieurs niveaux. Cela inclut des techniques telles que l'entraînement de faible précision, le partitionnement, l'optimisation de la communication, la rematérialisation de l'activation, le réglage de la mémoire virtuelle à portée limitée et les noyaux d'accélérateur personnalisés.
Pour optimiser les performances avec TPU7x, vous devez d'abord vous familiariser avec l'architecture Ironwood, en particulier la hiérarchie de mémoire et la topologie d'interconnexion. Pour en savoir plus, consultez TPU7x (Ironwood).
Entraînement de faible précision avec FP8
FP8 (virgule flottante de 8 bits) est un format de données numériques efficace utilisé principalement pour accélérer l'entraînement et l'inférence des modèles. En représentant les nombres à l'aide de 8 bits, au lieu des formats standard de 16 bits (FP16 ou BF16) et de 32 bits (FP32), les TPU peuvent traiter les données beaucoup plus rapidement et utiliser moins de mémoire.
TPU7x est compatible avec l'accélération matérielle intégrée pour les types de données FP8, offrant une performance théorique maximale de 4 614 TFLOPS par puce. Cette fonctionnalité peut réduire considérablement les temps d'entraînement de bout en bout. Pour les opérations compatibles, en particulier les multiplications de matrices denses courantes pour les charges de travail d'IA, l'utilisation de FP8 peut améliorer les performances de 1,3 fois par rapport à l'entraînement BF16 standard. Par rapport à BF16, FP8 double les FLOPS de pointe et divise par deux l'empreinte mémoire pour les pondérations et les activations. FP8 doit être un levier de réglage principal pour les charges de travail liées au calcul et les scénarios limités par la capacité ou la bande passante de la mémoire.
L'utilisation de FP8 offre les avantages suivants en termes de performances :
- Pression réduite sur la mémoire à haut débit (HBM) : une empreinte mémoire plus petite permet aux modèles plus grands ou aux modèles avec des caches KV plus grands pendant l'inférence de tenir entièrement dans les 192 Go de HBM. Cela évite un déchargement coûteux vers une mémoire hôte plus lente.
- Taille de lot effective accrue : en réduisant la mémoire requise pour les activations, FP8 permet d'utiliser des tailles de lot plus importantes. Cela améliore le parallélisme des données et peut entraîner un débit plus élevé et une meilleure utilisation des unités de calcul.
- Exigences de bande passante mémoire réduites : le déplacement de la moitié de la quantité de données pour chaque opération réduit la demande sur le chemin de données HBM vers MXU. Sur les systèmes où le déplacement de données est un goulot d'étranglement courant, cela permet de maintenir les MXU saturées de travail.
Pour utiliser FP8 avec une dégradation des performances nulle ou limitée, il est nécessaire de sélectionner avec soin les techniques de quantification. Voici quelques bonnes pratiques à prendre en compte pour l'entraînement FP8 :
- Granularité de la mise à l'échelle : commencez par la mise à l'échelle par Tensor comme référence. En cas de problème de qualité ou de performances, passez à la mise à l'échelle par axe. La mise à l'échelle des sous-canaux n'est peut-être pas nécessaire.
- Mode Scaling : le scaling dynamique, qui calcule les facteurs de scaling au moment de l'exécution, est une bonne option par défaut pour maintenir la qualité. Bien que le scaling statique puisse améliorer considérablement les performances en éliminant les calculs, il nécessite un profilage minutieux pour déterminer les facteurs de scaling appropriés et peut ne pas convenir à tous les cas d'utilisation, en particulier lorsque les configurations de modèle changent. À l'inverse, certains modèles et configurations robustes peuvent fixer l'échelle à la limite FP8 pour les pondérations ou les activations, ce qui vous permet de réduire la surcharge de quantification tout en conservant la précision et en améliorant les performances.
- Formats FP8 (E4M3 et E5M2) : une approche courante et efficace consiste à utiliser un mélange de formats FP8. Par exemple, utilisez E4M3 pour les pondérations et les activations dans la passe avant afin de profiter de la précision supérieure d'E4M3, et utilisez E5M2 pour les gradients dans la passe arrière afin de tenir compte de la plage dynamique plus large des gradients.
- Arrondi : l'utilisation de l'arrondi au nombre pair le plus proche (RNE, round to nearest even) au lieu de l'arrondi stochastique pour les gradients peut maintenir la qualité tout en offrant de meilleures performances et une meilleure reproductibilité.
- Activer FP8 dans MaxText :
MaxText est compatible avec l'entraînement FP8
via la bibliothèque de quantification QWIX. Pour activer la quantification, définissez l'indicateur suivant dans votre configuration :
use_qwix_quantization=true.
Partitionnement et parallélisme
Le partitionnement consiste à découper un grand modèle ou ses données d'entraînement en plus petits morceaux et à les répartir sur plusieurs puces ou cœurs TPU. Il est important de choisir la bonne stratégie de partitionnement pour obtenir des performances élevées sur les TPU7x.
Une approche naïve qui maximise uniquement le degré de parallélisme entraînera souvent de mauvaises performances en devenant liée à la communication. La meilleure approche consiste souvent à sélectionner la stratégie de partitionnement la plus simple qui respecte les contraintes de mémoire, car cela minimise la surcharge de communication et permet d'utiliser efficacement les unités de calcul.
Avant de sélectionner une stratégie de partitionnement, la première étape de tout effort d'optimisation des performances doit être une analyse de l'intensité arithmétique. Cette analyse détermine si un calcul donné est limité par la puissance de calcul, la bande passante mémoire ou la bande passante d'interconnexion. Il est calculé comme le rapport entre les opérations à virgule flottante et les octets de données à déplacer.
Une intensité arithmétique élevée indique une charge de travail liée au calcul. Une faible intensité arithmétique suggère une charge de travail liée à la mémoire ou à la communication, où les performances sont limitées par la vitesse à laquelle les données peuvent être déplacées depuis la mémoire HBM ou sur le réseau ICI. Cette analyse permet de déterminer la taille de lot et la stratégie de partitionnement idéales. Par exemple, une charge de travail liée à la communication ne bénéficiera pas d'une stratégie de partitionnement qui introduit encore plus de communication, comme le parallélisme tensoriel de haut degré.
Cadre de décision pour la stratégie de sharding
MaxText propose différentes stratégies de sharding. Le choix optimal dépend de l'architecture du modèle, de la longueur de la séquence et de la nécessité d'équilibrer la charge de calcul par rapport à la surcharge de communication.
- Parallélisme des données entièrement segmentées (FSDP) : il s'agit de la stratégie par défaut privilégiée pour le parallélisme des données. FSDP segmente les pondérations, les gradients et les états de l'optimiseur du modèle sur les appareils parallèles aux données. Lors du calcul, chaque appareil effectue une opération All-Gather pour récupérer les poids complets nécessaires à son microbatch local. Le FSDP est très efficace tant que la taille de lot par appareil est suffisamment grande pour masquer la latence de cette communication All-Gather. Pour les modèles Mixture-of-Experts (MoE), le calcul de l'intensité arithmétique doit tenir compte de la parcimonie.
- Parallélisme Tensor (TP) : le TP segmente les Tensors individuels sur plusieurs appareils. En général, les Tensors sont des matrices de pondération dans les blocs de perceptron multicouche (MLP) et d'attention. L'intensité arithmétique élevée du matériel (11, 5 k) impose une exigence très élevée sur les dimensions du modèle pour rendre le TP viable par rapport à l'ICI.Tenter d'utiliser le TP peut entraîner une limitation des communications du système.
- Parallélisme expert (EP) : il s'agit de la stratégie standard et nécessaire pour entraîner les modèles MoE. EP fragmente les couches "expert" sur un ensemble d'appareils, et un collectif de communication All-to-All est utilisé pour acheminer les jetons vers leur appareil expert désigné. L'EP peut être efficace si la dimension MLP du modèle est suffisamment grande pour approcher la roofline.
- Parallélisme de contexte (CP) : le CP est une stratégie spécialisée qui est essentielle pour entraîner des modèles avec des séquences très longues. Sa fonction principale est de gérer la consommation de mémoire des activations, qui augmente de manière quadratique avec la longueur de la séquence et peut dépasser la capacité HBM. CP fragmente la dimension de séquence des Tensors d'activation, ce qui permet d'utiliser une taille de lot fractionnaire par appareil. Étant donné que CP introduit plus de communication que FSDP, la règle générale consiste à utiliser le degré minimal de CP nécessaire pour respecter les contraintes de mémoire et s'assurer que le shard de l'axe de lot reste un entier.
Le tableau suivant mappe les types de charge de travail courants à la stratégie de partitionnement optimale :
| Type de charge de travail | Segmentation principale recommandée | Segmentation secondaire | Principaux goulots d'étranglement | Explication |
|---|---|---|---|---|
| Modèle dense : séquence courte | FSDP | N/A | Rematérialisation, FF Matmuls | Le FSDP offre le meilleur équilibre. Avec des séquences courtes, la mémoire d'activation n'est pas forcément un problème majeur. La clé est un lot global suffisamment grand pour masquer l'agrégation de poids FSDP. À mesure que la taille du lot augmente, la taille de l'activation augmente également. Une stratégie de rematérialisation appropriée est donc nécessaire pour s'assurer que cette configuration ne manque pas de mémoire. |
| Modèle dense : séquence longue | FSDP | CP | Attention flash, mémoire d'activation | La mémoire d'activation devient la contrainte principale. CP est nécessaire pour activer les tailles de lot fractionnaires par appareil et éviter les problèmes de mémoire insuffisante (OOM). L'attention flash est la principale source de calcul et de perte de temps. |
| Modèle MoE : séquence courte | FSDP + EP | N/A | All-to-All (routage expert), rematérialisation | Les modèles MoE nécessitent un EP pour fragmenter les experts. La communication All-to-All pour le routage des jetons est un goulot d'étranglement majeur qui doit être chevauché. La rematérialisation est également une source importante de déchets. |
| Modèle MoE : très grande échelle | FSDP + EP + PP | Parallélisme du modèle (MP) | Tous les goulots d'étranglement mentionnés précédemment, ainsi que les bulles de pipeline | Pour les modèles qui dépassent la mémoire d'un seul pod, le PP est nécessaire pour fragmenter les couches entre les pods. Cela introduit les frais généraux de communication et de pipeline DCN. Il s'agit d'une configuration très complexe qui nécessite un réglage minutieux. |
Optimisation de la communication
Le principal mécanisme de chevauchement de la communication et du calcul sur TPU7x est appelé SparseCore Collective Offloading. L'architecture Ironwood inclut des unités SparseCore dédiées, qui agissent comme des threads de contrôle indépendants capables de gérer le déplacement des données sur le réseau ICI. Cela permet aux opérations de communication collective (comme All-Gather ou Reduce-Scatter) de s'exécuter en parallèle avec les calculs principaux effectués sur les TensorCores. Il s'agit de la méthode recommandée pour les collectifs asynchrones sur TPU7x. Utilisez les flags recommandés pour activer le déchargement des collectifs les plus courants.
Rematérialisation de l'activation
La rematérialisation des activations, également appelée checkpointing de gradient, est une technique fondamentale pour réduire l'empreinte HBM d'un modèle. Au lieu de stocker toutes les activations intermédiaires de la propagation avant dans la mémoire HBM pour les utiliser lors de la rétropropagation, il n'enregistre que quelques activations clés (points de contrôle) et recalcule les autres à la demande lors de la rétropropagation. Cela permet d'économiser une quantité importante de mémoire au prix d'un calcul accru (environ 25 à 30 % de FLOP supplémentaires pour un bloc de transformateur standard).
La décision d'appliquer la rematérialisation de manière agressive est un paramètre de réglage essentiel qui dépend entièrement du principal goulot d'étranglement, qui varie souvent en fonction de la longueur de la séquence.
Pour les charges de travail à longue séquence (telles que 128 k) : dans ces cas, la taille des Tensors d'activation est le principal consommateur de HBM. La charge de travail est généralement liée à la mémoire. Par conséquent, il est très avantageux d'appliquer une stratégie de rematérialisation agressive. Les économies de mémoire permettent à l'entraînement de se poursuivre sans erreurs de mémoire insuffisante et autorisent également des tailles de lot plus importantes. Le coût de calcul de la réexécution est un compromis intéressant.
Pour les charges de travail à séquence courte (comme 8k) : dans ce cas, la mémoire d'activation est beaucoup moins préoccupante et la charge de travail est plus susceptible d'être liée au calcul. Les frais généraux de calcul de la rematérialisation peuvent être la principale source d'inefficacité.
Ajuster les règles de rematérialisation dans MaxText
MaxText offre un contrôle précis de la rematérialisation grâce à un ensemble de règles prédéfinies et personnalisées, configurées à l'aide de l'indicateur remat_policy.
Règles prédéfinies
MaxText propose les règles intégrées suivantes :
full: stratégie la plus agressive, qui rematérialise presque tout. Cela minimise l'utilisation de la HBM, mais maximise la surcharge de recalcul. (idéal pour les scénarios à longue séquence et à mémoire extrêmement limitée)minimal: stratégie la moins agressive, qui stocke la plupart des activations. Cela maximise l'utilisation de la HBM, mais minimise le recalcul. Idéal pour les charges de travail à séquence courte et liées au calcul, où la mémoire n'est pas un problème.- Stratégies intermédiaires : les options telles que
save_dot_with_context_except_mlp,save_qkv_projetsave_out_projoffrent différents compromis en enregistrant sélectivement les sorties des opérations de produit scalaire coûteuses tout en rematérialisant les opérations élément par élément moins coûteuses.
Règles personnalisées
Pour un contrôle plus précis, vous pouvez définir remat_policy sur custom. Cela vous permet de spécifier le comportement des couches individuelles dans le module de décodage du modèle. Chaque calque peut se voir attribuer l'un des trois comportements suivants :
device: l'activation est stockée dans la mémoire à haut débit de l'appareil TPU.remat: l'activation est supprimée et sera rematérialisée lors de la rétropropagation.offload: l'activation est déplacée de la mémoire HBM vers la mémoire de l'hôte du processeur, ce qui libère de la mémoire HBM au détriment de la latence de transfert PCIe.
Optimisation de VMEM à portée limitée
Les performances du noyau, comme l'attention flash, dépendent des tailles de tuiles sélectionnées dans le noyau, dont la taille est limitée par la mémoire vectorielle (VMEM) disponible. Les puces TPU7x disposent de 64 Mo de VMEM, qui peuvent être répartis entre la portée actuelle (VMEM à portée limitée) et la prélecture des futurs poids. L'augmentation de la VMEM à portée permet d'accroître la taille des tuiles dans le noyau, ce qui peut réduire les blocages de mémoire et améliorer les performances des noyaux. Vous pouvez modifier la taille de la VMEM à portée limitée en définissant xla_tpu_scoped_vmem_limit_kib (dans LIBTPU_INIT_ARGS), qui peut être utilisé pour explorer les performances du noyau ainsi que les limites de performances de bout en bout.
L'optimisation de la taille de la VMEM à portée limitée peut affecter indirectement les performances du noyau Pallas personnalisé, car l'augmentation de la VMEM à portée limitée débloque un espace de recherche d'hyperparamètres plus grand pour les tailles de blocs dans le noyau.
Noyaux Tokamax
Tokamax, une bibliothèque de kernels JAX hautes performances avec de nombreux kernels TPU hautement optimisés, résout plusieurs goulots d'étranglement matériels courants :
- Attention Splash : l'attention Splash est utilisée comme implémentation d'attention principale pour éliminer le goulot d'étranglement HBM de l'attention standard et utilise l'implémentation d'attention la plus efficace sur les TPU.
- Multiplication matricielle groupée Megablox (GMM) : pour les charges de travail MoE, Megablox gère efficacement les multiplications matricielles groupées en effectuant des calculs sur la représentation des activations irrégulières. Il mappe efficacement la dimension irrégulière, en calculant les multiplications de matrices entre les groupes irréguliers de lignes dans LHS et la matrice d'experts correspondante, ce qui évite d'avoir à remplir les lots à une taille fixe.
- Réglage empirique avec
tune-jax: la bibliothèquetune-jaxdispose d'utilitaires permettant d'effectuer des recherches empiriques pour trouver les tailles de bloc optimales. Les tailles de noyau par défaut sont souvent sous-optimales. Le réglage permet de choisir des tailles de blocs VMEM compatibles avec le matériel pour maximiser l'utilisation du matériel. - Estimation des logits max : le noyau d'attention Tokamax Splash peut être optimisé davantage en définissant une valeur pour
max_logit_const. Si cette valeur est définie, elle remplace le calcul de réduction du logit maximal lors de l'opération softmax de l'attention (softmax(Q * KT)), ce qui réduit une partie de la surcharge de calcul et de synchronisation. Dans MaxText, il est implémenté par la configurationuse_max_logits_estimate, qui peut être définie surNone(désactivé) ou sur une valeur à virgule flottante. Vérifiez que la plage de logits de votre modèle spécifique reste compatible avec l'estimation pour éviter le dépassement numérique. Il est recommandé de tester la convergence si cette valeur est définie.