Ottimizzazioni delle prestazioni di TPU7x (Ironwood)
Questa guida descrive diversi metodi per ottimizzare le prestazioni con TPU7x (Ironwood) gestendo in modo efficiente lo spostamento dei dati tra il sistema di memoria a più livelli. Sono incluse tecniche come l'addestramento a bassa precisione, lo sharding, l'ottimizzazione della comunicazione, la rimaterializzazione dell'attivazione, l'ottimizzazione della memoria virtuale con ambito e i kernel degli acceleratori personalizzati.
Per ottimizzare le prestazioni con TPU7x, devi prima acquisire familiarità con l'architettura Ironwood, in particolare con la gerarchia della memoria e la topologia di interconnessione. Per ulteriori informazioni, consulta TPU7x (Ironwood).
Addestramento a bassa precisione con FP8
FP8 (floating point a 8 bit) è un formato di dati numerici efficiente utilizzato principalmente per accelerare l'addestramento e l'inferenza dei modelli. Rappresentando i numeri utilizzando 8 bit anziché i formati standard a 16 bit (FP16 o BF16) e a 32 bit (FP32), le TPU possono elaborare i dati in modo significativamente più rapido e utilizzare meno memoria.
TPU7x supporta l'accelerazione hardware integrata per i tipi di dati FP8, offrendo una prestazione teorica di picco di 4614 TFLOPS per chip. Questa funzionalità può ridurre in modo significativo i tempi di addestramento end-to-end. Per le operazioni compatibili, in particolare le moltiplicazioni di matrici dense comuni per i carichi di lavoro AI, l'utilizzo di FP8 può comportare miglioramenti delle prestazioni di 1,3 volte rispetto all'addestramento BF16 standard. Rispetto a BF16, FP8 raddoppia i FLOP di picco e dimezza il footprint della memoria per pesi e attivazioni. FP8 dovrebbe essere un leva di ottimizzazione principale sia per i carichi di lavoro con vincoli di calcolo sia per gli scenari vincolati dalla capacità o dalla larghezza di banda della memoria.
L'utilizzo di FP8 offre i seguenti vantaggi in termini di prestazioni:
- Pressione ridotta della memoria HBM (High-Bandwidth Memory): un footprint della memoria inferiore consente di adattare modelli più grandi o modelli con cache KV più grandi durante l'inferenza all'interno dei 192 GB di HBM. In questo modo si evita il costoso scaricamento nella memoria host più lenta.
- Dimensioni del batch effettive maggiori: riducendo la memoria richiesta per le attivazioni, FP8 consente l'utilizzo di dimensioni del batch maggiori. In questo modo si migliora il parallelismo dei dati e si può ottenere una velocità effettiva maggiore e un utilizzo migliore delle unità di calcolo.
- Requisiti di larghezza di banda della memoria inferiori: lo spostamento della metà della quantità di dati per ogni operazione riduce la domanda sul percorso dati HBM-MXU. Nei sistemi in cui lo spostamento dei dati è un collo di bottiglia comune, questo aiuta a mantenere le MXU sature di lavoro.
L'utilizzo di FP8 con una degradazione delle prestazioni pari a zero o limitata richiede un'attenta selezione delle tecniche di quantizzazione. Ecco alcune best practice da tenere presenti per l'addestramento FP8:
- Granularità di scalabilità: inizia con la scalabilità per tensore come baseline. Se si verificano problemi di qualità o prestazioni, passa alla scalabilità per asse. La scalabilità dei sottocanali potrebbe non essere necessaria.
- Modalità di scalabilità: la scalabilità dinamica, che calcola i fattori di scalabilità in fase di runtime, è una buona impostazione predefinita per mantenere la qualità. Sebbene la scalabilità statica possa offrire un aumento significativo delle prestazioni eliminando i calcoli, richiede un'attenta profilazione per determinare i fattori di scalabilità corretti e potrebbe non essere adatta a tutti i casi d'uso, soprattutto quando le configurazioni dei modelli cambiano. Al contrario, alcuni modelli e configurazioni robusti possono correggere la scalabilità al limite FP8 per pesi o attivazioni, consentendoti di ridurre l'overhead di quantizzazione mantenendo la precisione e migliorando le prestazioni.
- Formati FP8 (E4M3 ed E5M2): un approccio comune ed efficace consiste nell'utilizzare un mix di formati FP8. Ad esempio, utilizza E4M3 per pesi e attivazioni nel passaggio in avanti per sfruttare la maggiore precisione di E4M3 e utilizza E5M2 per i gradienti nel passaggio all'indietro per adattarsi alla gamma dinamica più ampia dei gradienti.
- Arrotondamento: l'utilizzo di "arrotondamento al numero pari più vicino" (RNE) anziché l'arrotondamento stocastico per i gradienti può mantenere la qualità offrendo al contempo prestazioni e riproducibilità migliori.
- Abilitazione di FP8 in MaxText:
MaxText supporta l'addestramento FP8
tramite la libreria di quantizzazione QWIX. Per attivare la quantizzazione, imposta il seguente flag nella configurazione:
use_qwix_quantization=true.
Sharding e parallelismo
Lo sharding è il processo di suddivisione di un modello di grandi dimensioni o dei relativi dati di addestramento in parti più piccole e della loro distribuzione su più chip o core TPU. La scelta della strategia di sharding corretta è importante per ottenere prestazioni elevate su TPU7x.
Un approccio ingenuo che massimizza puramente il grado di parallelismo spesso comporta prestazioni scadenti diventando vincolato alla comunicazione. L'approccio migliore consiste spesso nel selezionare la strategia di sharding più semplice che soddisfi i vincoli di memoria, in quanto riduce al minimo l'overhead di comunicazione e consente di utilizzare in modo efficiente le unità di calcolo.
Prima di selezionare una strategia di sharding, il primo passaggio di qualsiasi tentativo di ottimizzazione delle prestazioni dovrebbe essere un'analisi dell'intensità aritmetica. Questa analisi determina se un determinato calcolo è limitato dal calcolo, dalla larghezza di banda della memoria o dalla larghezza di banda di interconnessione. Viene calcolata come il rapporto tra le operazioni in rappresentazione in virgola mobile e i byte di dati che devono essere spostati.
Un'elevata intensità aritmetica indica un carico di lavoro con vincoli di calcolo. Una bassa intensità aritmetica suggerisce un carico di lavoro con vincoli di memoria o comunicazione, in cui le prestazioni sono limitate dalla velocità con cui i dati possono essere spostati da HBM o nella rete ICI. Questa analisi fornisce informazioni sulla dimensione del batch e sulla strategia di sharding ideali. Ad esempio, un carico di lavoro con vincoli di comunicazione non trarrà vantaggio da una strategia di sharding che introduce ancora più comunicazione, come il parallelismo dei tensori di alto grado.
Framework decisionale per la strategia di sharding
MaxText offre una serie di strategie di sharding. La scelta ottimale dipende dall'architettura del modello, dalla lunghezza della sequenza e dalla necessità di bilanciare il carico computazionale con l'overhead di comunicazione.
- Parallelismo dei dati completamente suddivisi (FSDP): questa è la strategia predefinita preferita per il parallelismo dei dati. FSDP suddivide i pesi del modello, i gradienti e gli stati dell'ottimizzatore tra i dispositivi paralleli ai dati. Durante il calcolo, ogni dispositivo esegue un'operazione All-Gather per recuperare i pesi completi necessari per il microbatch locale. FSDP è molto efficace a condizione che la dimensione del batch per dispositivo sia sufficientemente grande da nascondere la latenza di questa comunicazione All-Gather. Per i modelli Mixture-of-Experts (MoE), il calcolo dell'intensità aritmetica deve tenere conto della scarsità.
- Parallelismo dei tensori (TP): TP suddivide i singoli tensori tra i dispositivi. In genere, i tensori sono matrici di pesi nei blocchi di attenzione e di percettrone multistrato (MLP). L'elevata intensità aritmetica dell'hardware (11, 5 k) impone un requisito molto elevato sulle dimensioni del modello per rendere TP praticabile su ICI e il tentativo di utilizzare TP può comportare che il sistema sia vincolato alla comunicazione.
- Parallelismo degli esperti (EP): questa è la strategia standard e necessaria per l'addestramento dei modelli MoE. EP suddivide i livelli "esperti" su un insieme di dispositivi e viene utilizzato un collettivo di comunicazione All-to-All per instradare i token al dispositivo esperto designato. EP può essere efficiente se la dimensione MLP del modello è sufficientemente grande da avvicinarsi al limite massimo.
- Parallelismo del contesto (CP): CP è una strategia specializzata essenziale per l'addestramento di modelli con lunghezze di sequenza molto lunghe. La sua funzione principale è gestire il consumo di memoria delle attivazioni, che cresce in modo quadratico con la lunghezza della sequenza e può superare la capacità HBM. CP suddivide la dimensione della sequenza dei tensori di attivazione, il che consente l'utilizzo di una dimensione del batch frazionaria per dispositivo. Poiché CP introduce più comunicazione rispetto a FSDP, la regola generale è di utilizzare il grado minimo di CP necessario per soddisfare i vincoli di memoria e garantire che la suddivisione dell'asse del batch rimanga un numero intero.
La tabella seguente mappa i tipi di carico di lavoro comuni alla strategia di sharding ottimale:
| Tipo di carico di lavoro | Sharding primario consigliato | Sharding secondario | Colli di bottiglia principali | Rationale |
|---|---|---|---|---|
| Modello denso - sequenza breve | FSDP | N/D | Rimaterializzazione, FF Matmuls | FSDP offre il miglior equilibrio. Con sequenze brevi, la memoria di attivazione potrebbe non essere un problema importante. La chiave è un batch globale sufficientemente grande da nascondere l'All-Gather dei pesi di FSDP. Man mano che la dimensione del batch aumenta, aumenta anche la dimensione dell'attivazione ed è necessaria una policy di rimaterializzazione adatta per garantire che questa configurazione non esaurisca la memoria. |
| Modello denso - sequenza lunga | FSDP | CP | Attenzione flash, memoria di attivazione | La memoria di attivazione diventa il vincolo principale. CP è necessario per abilitare le dimensioni del batch frazionarie per dispositivo ed evitare problemi di esaurimento della memoria (OOM) . L'attenzione flash è la fonte dominante di calcolo e tempo perso. |
| Modello MoE - sequenza breve | FSDP + EP | N/D | All-to-All (instradamento degli esperti), rimaterializzazione | I modelli MoE richiedono EP per suddividere gli esperti. La comunicazione All-to-All per l'instradamento dei token è un collo di bottiglia importante che deve essere sovrapposto. Anche la rimaterializzazione è una fonte significativa di spreco. |
| Modello MoE - scala molto grande | FSDP + EP + PP | Parallelismo del modello (MP) | Tutti i colli di bottiglia menzionati in precedenza, oltre alle bolle della pipeline | Per i modelli che superano la memoria di un singolo pod, è necessario PP per suddividere i livelli tra i pod. In questo modo vengono introdotti overhead di comunicazione DCN e bolle della pipeline. Si tratta di una configurazione molto complessa che richiede un'attenta ottimizzazione. |
Ottimizzazione della comunicazione
Il meccanismo principale per sovrapporre comunicazione e calcolo su TPU7x è chiamato SparseCore Collective Offloading. L'architettura Ironwood include unità SparseCore dedicate, che fungono da thread di controllo indipendenti in grado di gestire lo spostamento dei dati sulla struttura ICI. In questo modo, le operazioni di comunicazione collettiva (come All-Gather o Reduce-Scatter) possono essere eseguite in parallelo con i calcoli principali che avvengono sui TensorCore. Questo è il metodo consigliato per i collettivi asincroni su TPU7x. Utilizza i consigliati flag per abilitare l'offload per i collettivi più comuni.
Rimaterializzazione dell'attivazione
La rimaterializzazione dell'attivazione, nota anche come checkpointing dei gradienti, è una tecnica fondamentale per ridurre l'utilizzo di HBM di un modello. Anziché archiviare tutte le attivazioni intermedie del passaggio in avanti in HBM da utilizzare durante il passaggio all'indietro, salva solo alcune attivazioni chiave (checkpoint) e ricalcola le altre su richiesta durante il passaggio all'indietro. In questo modo si risparmia una quantità significativa di memoria a costo di un aumento del calcolo (circa il 25-30% di FLOP aggiuntivi per un blocco di trasformatori standard).
La decisione di quanto aggressivamente applicare la rimaterializzazione è un parametro di ottimizzazione fondamentale che dipende interamente dal collo di bottiglia principale, che spesso varia in base alla lunghezza della sequenza.
Per i carichi di lavoro con sequenze lunghe (ad esempio 128k): in questi casi, la dimensione dei tensori di attivazione è il consumatore dominante di HBM. Il carico di lavoro è in genere vincolato alla memoria. Pertanto, l'applicazione di una policy di rimaterializzazione aggressiva è molto vantaggiosa. Il risparmio di memoria consente di procedere con l'addestramento senza errori di esaurimento della memoria e consente anche dimensioni del batch maggiori, mentre l'overhead computazionale del ricalcolo è un compromesso valido.
Per i carichi di lavoro con sequenze brevi (ad esempio 8k): in questi casi, la memoria di attivazione è molto meno un problema e il carico di lavoro è più probabile che sia vincolato al calcolo. L'overhead computazionale della rimaterializzazione può essere la singola fonte di inefficienza più grande.
Ottimizzazione delle policy di rimaterializzazione in MaxText
MaxText fornisce un controllo granulare sulla rimaterializzazione tramite un insieme di policy preimpostate e personalizzate, configurate utilizzando il flag remat_policy.
Policy preimpostate
MaxText offre le seguenti policy integrate:
full: la policy più aggressiva, che rimaterializza quasi tutto. In questo modo si riduce al minimo l'utilizzo di HBM, ma si massimizza l'overhead di ricalcolo. Ideale per scenari con sequenze lunghe e con vincoli di memoria estremi.minimal: la policy meno aggressiva, che archivia la maggior parte delle attivazioni. In questo modo si massimizza l'utilizzo di HBM, ma si riduce al minimo il ricalcolo. Ideale per carichi di lavoro con sequenze brevi e con vincoli di calcolo in cui la memoria non è un problema.- Policy intermedie: opzioni come
save_dot_with_context_except_mlp,save_qkv_projesave_out_projforniscono vari compromessi eseguendo il checkpointing selettivo degli output di operazioni di prodotto scalare costose durante la rimaterializzazione di operazioni elementari più economiche.
Policy personalizzate
Per un maggiore livello di controllo, puoi impostare remat_policy su custom. In questo modo puoi specificare il comportamento per i singoli livelli all'interno del modulo di decodifica del modello. A ogni livello può essere assegnato uno dei tre comportamenti:
device: l'attivazione viene archiviata in HBM sul dispositivo TPU.remat: l'attivazione viene eliminata e verrà rimaterializzata durante il passaggio all'indietro.offload: l'attivazione viene spostata da HBM alla memoria dell'host della CPU, liberando HBM a costo della latenza di trasferimento PCIe.
Ottimizzazione di VMEM con ambito
Le prestazioni del kernel, come l'attenzione flash, dipendono dalle dimensioni dei riquadri selezionate nel kernel, la cui dimensione è limitata dalla memoria vettoriale (VMEM) disponibile. Ognuno dei due TensorCore in un chip TPU7x ha 64 MiB di memoria vettoriale (VMEM). Questa capacità VMEM può essere suddivisa tra l'ambito corrente (VMEM con ambito) e il precaricamento dei pesi futuri. L'aumento di VMEM con ambito consente di aumentare le dimensioni dei riquadri nel kernel, riducendo potenzialmente gli stalli di memoria e aumentando le prestazioni dei kernel. Puoi modificare la dimensione di VMEM con ambito impostando xla_tpu_scoped_vmem_limit_kib (in LIBTPU_INIT_ARGS), che può essere utilizzata per esplorare le prestazioni del kernel e i limiti di prestazioni end-to-end. L'ottimizzazione della dimensione di VMEM con ambito può influire indirettamente sulle prestazioni del kernel Pallas personalizzato, poiché l'aumento di VMEM con ambito sblocca uno spazio di ricerca degli iperparametri più ampio per le dimensioni dei riquadri in-kernel.
Kernel Tokamax
Tokamax, una libreria di kernel JAX ad alte prestazioni con molti kernel TPU altamente ottimizzati, risolve diversi colli di bottiglia comuni specifici dell'hardware:
- Attenzione splash: l'attenzione splash viene utilizzata come implementazione dell'attenzione principale per eliminare il collo di bottiglia HBM dell'attenzione standard e utilizza l'implementazione dell'attenzione più efficiente sulle TPU.
- Moltiplicazione di matrici raggruppate Megablox (GMM): per i carichi di lavoro MoE, Megablox gestisce in modo efficiente le moltiplicazioni di matrici raggruppate eseguendo il calcolo sulla rappresentazione delle attivazioni irregolari. Esegue in modo efficiente il mapping sulla dimensione irregolare, calcolando le moltiplicazioni di matrici tra gruppi di righe irregolari in LHS e la matrice esperta corrispondente, evitando la necessità di aggiungere batch a una dimensione fissa.
- Ottimizzazione empirica con
tune-jax: la libreriatune-jaxinclude utilità per eseguire ricerche empiriche delle dimensioni dei blocchi ottimali. Le dimensioni predefinite del kernel sono spesso non ottimali; l'ottimizzazione consente di scegliere dimensioni dei riquadri VMEM compatibili con l'hardware per massimizzare l'utilizzo dell'hardware. - Stima dei logit massimi: il kernel di attenzione splash Tokamax può essere ulteriormente
ottimizzato impostando un valore per
max_logit_const. Se impostato, sostituisce il calcolo di riduzione del logit massimo durante l'operazione softmax dell'attenzione (softmax(Q * KT)), riducendo alcuni overhead di calcolo e sincronizzazione. In MaxText, viene implementato dalla configurazioneuse_max_logits_estimate, che può essere impostata suNone(disabilitata) o su un valore in virgola mobile. Verifica che l'intervallo di logit del tuo modello specifico rimanga compatibile con la stima per evitare l'overflow numerico. Se questo valore è impostato, è consigliabile eseguire test di convergenza.