Ottimizzazioni delle prestazioni di TPU7x (Ironwood)

Questa guida descrive diversi metodi per ottimizzare le prestazioni con TPU7x (Ironwood) gestendo in modo efficiente lo spostamento dei dati tra il sistema di memoria multilivello. Ciò include tecniche come l'addestramento a bassa precisione, lo sharding, l'ottimizzazione della comunicazione, la rematerializzazione dell'attivazione, la regolazione della memoria virtuale con ambito e i kernel dell'acceleratore personalizzati.

Per ottimizzare le prestazioni con TPU7x, devi prima acquisire familiarità con l'architettura Ironwood, in particolare con la gerarchia di memoria e la topologia di interconnessione. Per maggiori informazioni, vedi TPU7x (Ironwood).

Addestramento a bassa precisione con FP8

FP8 (virgola mobile a 8 bit) è un formato di dati numerici efficiente utilizzato principalmente per accelerare l'addestramento e l'inferenza dei modelli. Rappresentando i numeri utilizzando 8 bit anziché i formati standard a 16 bit (FP16 o BF16) e a 32 bit (FP32), le TPU possono elaborare i dati molto più velocemente e utilizzare meno memoria.

TPU7x supporta l'accelerazione hardware integrata per i tipi di dati FP8, offrendo una prestazione teorica di picco di 4614 TFLOPS per chip. Questa funzionalità può portare a tempi di addestramento end-to-end molto più rapidi. Per le operazioni compatibili, in particolare le moltiplicazioni di matrici dense comuni per i carichi di lavoro AI, l'utilizzo di FP8 può migliorare le prestazioni di 1,3 volte rispetto all'addestramento BF16 standard. Rispetto a BF16, FP8 raddoppia i FLOP di picco e dimezza l'impronta di memoria per pesi e attivazioni. FP8 deve essere un parametro di ottimizzazione principale sia per i carichi di lavoro vincolati dal calcolo sia per gli scenari vincolati dalla capacità di memoria o dalla larghezza di banda.

L'utilizzo di FP8 offre i seguenti vantaggi in termini di prestazioni:

  • Pressione ridotta sulla memoria ad alta larghezza di banda (HBM): un footprint di memoria più piccolo consente a modelli più grandi o con cache KV più grandi durante l'inferenza di rientrare completamente nei 192 GB di HBM. In questo modo si evita un costoso scaricamento nella memoria host più lenta.
  • Aumento delle dimensioni effettive dei batch: riducendo la memoria richiesta per le attivazioni, FP8 consente l'utilizzo di dimensioni dei batch più grandi. Ciò migliora il parallelismo dei dati e può portare a una velocità effettiva maggiore e a un migliore utilizzo delle unità di calcolo.
  • Requisiti di larghezza di banda della memoria inferiori: lo spostamento della metà dei dati per ogni operazione riduce la domanda sul percorso dei dati da HBM a MXU. Sui sistemi in cui lo spostamento dei dati è un collo di bottiglia comune, questo aiuta a mantenere le MXU sature di lavoro.

L'utilizzo di FP8 con degradazione delle prestazioni pari a zero o limitata richiede un'attenta selezione delle tecniche di quantizzazione. Ecco alcune best practice da considerare per l'addestramento FP8:

  • Granularità dello scaling: inizia con lo scaling per tensore come base di riferimento. Se si verificano problemi di qualità o rendimento, passa al ridimensionamento per asse. Il ridimensionamento dei sottocanali potrebbe non essere necessario.
  • Modalità di scalabilità: la scalabilità dinamica, che calcola i fattori di scalabilità in fase di runtime, è un buon valore predefinito per mantenere la qualità. Sebbene lo scaling statico possa offrire un aumento significativo delle prestazioni eliminando i calcoli, richiede un'attenta profilazione per determinare i fattori di scalabilità corretti e potrebbe non essere adatto a tutti i casi d'uso, soprattutto quando le configurazioni del modello cambiano. Al contrario, alcuni modelli e configurazioni robusti possono correggere la scala al limite FP8 per pesi o attivazioni, consentendoti di ridurre l'overhead di quantizzazione mantenendo la precisione e migliorando le prestazioni.
  • Formati FP8 (E4M3 ed E5M2): un approccio comune ed efficace è utilizzare un mix di formati FP8. Ad esempio, utilizza E4M3 per pesi e attivazioni nel passaggio in avanti per sfruttare la maggiore precisione di E4M3 e utilizza E5M2 per i gradienti nel passaggio all'indietro per adattarsi alla gamma dinamica più ampia dei gradienti.
  • Arrotondamento: l'utilizzo dell'arrotondamento al numero pari più vicino (RNE) anziché dell'arrotondamento stocastico per i gradienti può mantenere la qualità offrendo al contempo prestazioni e riproducibilità migliori.
  • Abilitazione di FP8 in MaxText: MaxText supporta l'addestramento FP8 tramite la libreria di quantizzazione QWIX. Per attivare la quantizzazione, imposta il seguente flag nella configurazione: use_qwix_quantization=true.

Sharding e parallelismo

Lo sharding è il processo di suddivisione di un modello di grandi dimensioni o dei relativi dati di addestramento in parti più piccole e la loro distribuzione su più chip o core TPU. La scelta della strategia di sharding giusta è importante per ottenere prestazioni elevate su TPU7x.

Un approccio ingenuo che massimizza puramente il grado di parallelismo spesso comporta prestazioni scarse diventando vincolato alla comunicazione. L'approccio migliore è spesso quello di selezionare la strategia di sharding più semplice che soddisfi i vincoli di memoria, in quanto ciò riduce al minimo l'overhead di comunicazione e consente di utilizzare in modo efficiente le unità di calcolo.

Prima di selezionare una strategia di sharding, il primo passaggio di qualsiasi ottimizzazione delle prestazioni deve essere un'analisi dell'intensità aritmetica. Questa analisi determina se un determinato calcolo è limitato dalla capacità di calcolo, dalla larghezza di banda della memoria o dalla larghezza di banda dell'interconnessione. Viene calcolato come il rapporto tra le operazioni in virgola mobile e i byte di dati da spostare.

Un'elevata intensità aritmetica indica un workload legato al calcolo. Una bassa intensità aritmetica suggerisce un carico di lavoro vincolato alla memoria o alla comunicazione, in cui le prestazioni sono limitate dalla velocità con cui i dati possono essere spostati dalla HBM o nella rete ICI. Questa analisi fornisce informazioni sulla dimensione batch e sulla strategia di partizionamento ideali. Ad esempio, un carico di lavoro vincolato alla comunicazione non trarrà vantaggio da una strategia di sharding che introduce ancora più comunicazione, come il parallelismo dei tensori di alto grado.

Framework decisionale per la strategia di partizionamento

MaxText offre una serie di strategie di sharding. La scelta ottimale dipende dall'architettura del modello, dalla lunghezza della sequenza e dalla necessità di bilanciare il carico di calcolo rispetto all'overhead di comunicazione.

  • Parallelismo dei dati completamente partizionati (FSDP): questa è la strategia predefinita preferita per il parallelismo dei dati. FSDP suddivide i pesi, i gradienti e gli stati dell'ottimizzatore del modello tra i dispositivi con parallelismo dei dati. Durante il calcolo, ogni dispositivo esegue un'operazione All-Gather per recuperare i pesi completi necessari per il microbatch locale. FSDP è molto efficace se la dimensione del batch per dispositivo è sufficientemente grande da nascondere la latenza di questa comunicazione All-Gather. Per i modelli Mixture-of-Experts (MoE), il calcolo dell'intensità aritmetica deve tenere conto della scarsità.
  • Parallelismo dei tensori (TP): TP suddivide i singoli tensori tra i dispositivi. In genere, i tensori sono matrici di peso nel percettrone multistrato (MLP) e nei blocchi di attenzione. L'elevata intensità aritmetica dell'hardware (11, 5 k) impone un requisito molto elevato per le dimensioni del modello per rendere praticabile TP su ICI e il tentativo di utilizzare TP può comportare un sistema vincolato dalla comunicazione.
  • Parallelismo degli esperti (EP): questa è la strategia standard e necessaria per l'addestramento dei modelli MoE. EP suddivide i livelli "esperti" su un insieme di dispositivi e viene utilizzato un collettivo di comunicazione All-to-All per indirizzare i token al dispositivo esperto designato. EP può essere efficiente se la dimensione MLP del modello è sufficientemente grande da avvicinarsi alla roofline.
  • Parallelismo contestuale (CP): CP è una strategia specializzata essenziale per l'addestramento di modelli con sequenze molto lunghe. La sua funzione principale è gestire il consumo di memoria delle attivazioni, che aumenta in modo quadratico con la lunghezza della sequenza e può superare la capacità HBM. CP suddivide la dimensione della sequenza dei tensori di attivazione, il che consente l'utilizzo di una dimensione del batch frazionaria per dispositivo. Poiché CP introduce più comunicazione rispetto a FSDP, la regola generale è di utilizzare il grado minimo di CP necessario per soddisfare i vincoli di memoria e garantire che lo shard dell'asse batch rimanga un numero intero.

La tabella seguente mappa i tipi di carichi di lavoro comuni alla strategia di sharding ottimale:

Tipo di workload Sharding primario consigliato Sharding secondario Colli di bottiglia principali Rationale
Modello denso - sequenza breve FSDP N/D Rematerialization, FF Matmuls FSDP offre il miglior equilibrio. Con sequenze brevi, la memoria di attivazione potrebbe non essere un problema importante. La chiave è un batch globale sufficientemente grande per nascondere il peso di All-Gather di FSDP. Man mano che le dimensioni del batch aumentano, le dimensioni dell'attivazione aumentano ed è necessaria una politica di materializzazione adatta per garantire che questa configurazione non esaurisca la memoria.
Modello denso - sequenza lunga FSDP CP Attenzione flash, memoria di attivazione La memoria di attivazione diventa il vincolo principale. CP è necessario per attivare dimensioni del batch frazionarie per dispositivo ed evitare problemi di esaurimento della memoria (OOM). L'attenzione flash è la principale fonte di calcolo e tempo sprecato.
Modello MoE - sequenza breve FSDP + EP N/D All-to-All (routing esperto), rematerializzazione I modelli MoE richiedono EP per partizionare gli esperti. La comunicazione All-to-All per il routing dei token è un collo di bottiglia importante che deve essere sovrapposto. Anche la riproduzione è una fonte significativa di rifiuti.
Modello MoE - su larga scala FSDP + EP + PP Parallelismo dei modelli (MP) Tutti i colli di bottiglia menzionati in precedenza, più le bolle della pipeline Per i modelli che superano la memoria di un singolo pod, è necessario PP per dividere i livelli tra i pod. Vengono introdotti gli overhead di comunicazione e pipeline della DCN. Si tratta di una configurazione molto complessa che richiede un'attenta ottimizzazione.

Ottimizzazione della comunicazione

Il meccanismo principale per la comunicazione e il calcolo sovrapposti sulla TPU7x è chiamato SparseCore Collective Offloading. L'architettura Ironwood include unità SparseCore dedicate, che fungono da thread di controllo indipendenti in grado di gestire il movimento dei dati sulla struttura ICI. Ciò consente alle operazioni di comunicazione collettiva (come All-Gather o Reduce-Scatter) di essere eseguite in parallelo con i calcoli principali eseguiti sui Tensor Core. Questo è il metodo consigliato per le primitive collettive asincrone su TPU7x. Utilizza i flag consigliati per attivare l'offload per i collettivi più comuni.

Riematerializzazione dell'attivazione

La rematerializzazione dell'attivazione, nota anche come checkpointing del gradiente, è una tecnica fondamentale per ridurre l'impronta HBM di un modello. Anziché memorizzare tutte le attivazioni intermedie del passaggio in avanti in HBM da utilizzare durante il passaggio all'indietro, salva solo alcune attivazioni chiave (checkpoint) e ricalcola le altre su richiesta durante il passaggio all'indietro. In questo modo si risparmia una quantità significativa di memoria a costo di un aumento del calcolo (circa il 25-30% di FLOP aggiuntivi per un blocco Transformer standard).

La decisione di quanto aggressivamente applicare la materializzazione è un parametro di ottimizzazione critico che dipende interamente dal collo di bottiglia principale, che spesso varia in base alla lunghezza della sequenza.

Per i workload a sequenza lunga (ad esempio 128.000): in questi casi, le dimensioni dei tensori di attivazione sono il principale consumatore di HBM. Il carico di lavoro è in genere vincolato alla memoria. Pertanto, l'applicazione di una politica di rematerializzazione aggressiva è altamente vantaggiosa. Il risparmio di memoria consente di procedere con l'addestramento senza errori di memoria insufficiente e consente anche dimensioni del batch maggiori. L'overhead di calcolo della ricomputazione è un compromesso valido.

Per i carichi di lavoro a sequenza breve (ad esempio 8k): in questi casi, la memoria di attivazione è molto meno un problema e il carico di lavoro è più probabilmente vincolato al calcolo. L'overhead di calcolo della rematerializzazione può essere la principale fonte di inefficienza.

Ottimizzazione delle policy di materializzazione in MaxText

MaxText offre un controllo granulare della rematerializzazione tramite un insieme di policy preimpostate e personalizzate, configurate utilizzando il flag remat_policy.

Criteri preimpostati

MaxText offre le seguenti policy integrate:

  • full: la policy più aggressiva, che ricrea quasi tutto. In questo modo, l'utilizzo della HBM viene ridotto al minimo, ma il sovraccarico di ricalcolo viene massimizzato. Ideale per scenari con sequenze lunghe e vincoli di memoria estremamente elevati.
  • minimal: la policy meno aggressiva, che memorizza la maggior parte delle attivazioni. In questo modo l'utilizzo di HBM viene massimizzato, mentre il ricalcolo viene ridotto al minimo. Ideale per sequenze brevi, workload vincolati al calcolo in cui la memoria non è un problema.
  • Norme intermedie: opzioni come save_dot_with_context_except_mlp, save_qkv_proj e save_out_proj offrono vari compromessi controllando selettivamente gli output di operazioni di prodotto scalare costose e rimaterializzando operazioni più economiche elemento per elemento.

Norme personalizzate

Per un maggiore livello di controllo, puoi impostare remat_policy su custom. In questo modo puoi specificare il comportamento dei singoli livelli all'interno del modulo di decodifica del modello. A ogni livello può essere assegnato uno dei tre comportamenti seguenti:

  • device: l'attivazione è memorizzata nella HBM sul dispositivo TPU.
  • remat: l'attivazione viene eliminata e verrà materializzata nuovamente durante il passaggio all'indietro.
  • offload: l'attivazione viene spostata dalla HBM alla memoria dell'host CPU, liberando la HBM a scapito della latenza di trasferimento PCIe.

Ottimizzazione di VMEM con ambito

Le prestazioni del kernel, come l'attenzione flash, dipendono dalle dimensioni dei riquadri selezionati nel kernel, le cui dimensioni sono limitate dalla memoria vettoriale (VMEM) disponibile. I chip TPU7x hanno 64 MB di VMEM, che possono essere suddivisi tra l'ambito corrente (VMEM con ambito) e il precaricamento futuro dei pesi. L'aumento della VMEM con ambito consente di aumentare le dimensioni dei riquadri nel kernel, riducendo potenzialmente gli stalli della memoria e aumentando le prestazioni dei kernel. Puoi modificare le dimensioni di VMEM con ambito impostando xla_tpu_scoped_vmem_limit_kib (in LIBTPU_INIT_ARGS), che può essere utilizzato per esplorare le prestazioni del kernel e i limiti di rendimento end-to-end. L'ottimizzazione delle dimensioni della VMEM con ambito può influire indirettamente sulle prestazioni del kernel Pallas personalizzato, poiché l'aumento della VMEM con ambito sblocca uno spazio di ricerca degli iperparametri più ampio per le dimensioni dei riquadri nel kernel.

Kernel Tokamax

Tokamax, una libreria di kernel JAX ad alte prestazioni con molti kernel TPU altamente ottimizzati, risolve diversi colli di bottiglia comuni specifici dell'hardware:

  • Splash attention: Splash attention viene utilizzato come implementazione principale dell'attenzione per eliminare il collo di bottiglia HBM dell'attenzione standard e utilizza l'implementazione dell'attenzione più efficiente sulle TPU.
  • Moltiplicazione di matrici raggruppate (GMM) Megablox: per i carichi di lavoro MoE, Megablox gestisce in modo efficiente le moltiplicazioni di matrici raggruppate eseguendo i calcoli sulla rappresentazione delle attivazioni irregolari. Mappa in modo efficiente la dimensione irregolare, calcolando le moltiplicazioni di matrici tra gruppi irregolari di righe nel lato sinistro e la matrice esperta corrispondente, evitando la necessità di riempire i batch con una dimensione fissa.
  • Ottimizzazione empirica con tune-jax: la libreria tune-jax dispone di utilità per eseguire ricerche empiriche delle dimensioni ottimali dei blocchi. Le dimensioni predefinite del kernel sono spesso non ottimali; la regolazione consente di scegliere dimensioni dei riquadri VMEM compatibili con l'hardware per massimizzare l'utilizzo dell'hardware.
  • Stima logit massima: il kernel di attenzione Tokamax Splash può essere ulteriormente ottimizzato impostando un valore per max_logit_const. Se impostato, sostituisce il calcolo della riduzione del logit massimo durante l'operazione softmax di attenzione (softmax(Q * KT)), riducendo alcuni sovraccarichi di calcolo e sincronizzazione. In MaxText, viene implementato dalla configurazione use_max_logits_estimate, che può essere impostata su None (disabilitato) o su un valore in virgola mobile. Verifica che l'intervallo logit del tuo modello specifico rimanga compatibile con la stima per evitare l'overflow numerico. Se questo valore è impostato, è consigliabile eseguire il test di convergenza.