Scalare un modello sulle TPU

Questo documento fornisce una panoramica su come scalare i modelli linguistici: come funzionano le TPU e come comunicano tra loro, come vengono eseguiti i modelli LLM su hardware reale e come effettuare il parallelismo tra i modelli durante l'addestramento e l'inferenza in modo che vengano eseguiti in modo efficiente su larga scala. Forniamo informazioni che ti aiutano a valutare quanto sarà costoso addestrare un modello LLM, quanta memoria ti serve per pubblicare il modello e come eseguire lo sharding dei modelli in modo efficace su più TPU.

Anche se gran parte del deep learning è complessa, l'ottimizzazione delle prestazioni dei modelli non deve essere così, nemmeno su larga scala. I principi fondamentali si applicano ovunque, dalla gestione di un singolo acceleratore a decine di migliaia, e la loro comprensione ti consente di fare molte cose utili:

  • Stima quanto le parti del modello si avvicinano al loro valore ottimale teorico.
  • Fai scelte informate su diversi schemi di parallelismo a scale diverse (come dividere il calcolo su più dispositivi).
  • Stima il costo e il tempo necessari per addestrare ed eseguire modelli Transformer di grandi dimensioni.
  • Progetta algoritmi che sfruttano l'architettura TPU.
  • Progetta architetture di modelli basate su una comprensione esplicita di ciò che limita le prestazioni dell'algoritmo.

Prerequisiti

Dovresti avere una conoscenza di base dei modelli LLM e dell'architettura Transformer, ma non necessariamente di come funzionano su larga scala. Dovresti comprendere le nozioni di base dell'addestramento dei modelli LLM e, idealmente, avere una certa familiarità con JAX. Di seguito sono riportati alcuni documenti di riferimento utili per l'architettura Transformer:

Dopo aver acquisito familiarità con questi prerequisiti, dovresti sentirti a tuo agio a stimare lo schema di parallelismo migliore per un modello Transformer su una determinata piattaforma TPU. Potrai anche stimare la durata dell'addestramento e dell'inferenza.

Importanza dello scaling dei modelli

Oggi i modelli LLM e la maggior parte dei modelli di piccole dimensioni vengono eseguiti così vicino ai limiti hardware che lo sviluppo dei modelli richiede di pensare all'efficienza su larga scala. Un miglioramento del 20% nei benchmark è irrilevante se comporta un costo del 20% per l'efficienza del roofline. Le architetture di modelli promettenti falliscono regolarmente perché non possono essere eseguite in modo efficiente su larga scala o perché non sono stati compiuti sforzi di ottimizzazione per renderle tali.

L'obiettivo dello scaling dei modelli è quello di poter aumentare il numero di chip utilizzati per l'addestramento o l'inferenza ottenendo un aumento proporzionale e lineare della velocità effettiva. Questo è noto come scaling forte. Sebbene l'aggiunta di chip aggiuntivi (parallelismo) di solito riduca il tempo di calcolo, comporta anche un costo di comunicazione aggiuntiva tra i chip. Quando la comunicazione richiede più tempo del calcolo, il modello diventa vincolato alla comunicazione e non può essere scalato bene. Comprendere l'hardware abbastanza bene da prevedere dove si presenteranno questi colli di bottiglia ti consente di progettare o riconfigurare i modelli per evitarli.

Le sezioni seguenti forniscono una panoramica su come scalare l'hardware TPU e su come si è evoluta l'architettura Transformer. Queste informazioni sono utili sia per i ricercatori che progettano nuove architetture sia per gli ingegneri che lavorano per rendere veloce l'esecuzione della generazione attuale di modelli LLM.

Parte 1: concetti

Questa parte spiega l'analisi del roofline e i fattori che limitano la capacità di un modello di scalare (comunicazione, calcolo e memoria). Ora descriviamo come funzionano le TPU, sia come singoli chip sia, cosa di fondamentale importanza, come sistema interconnesso con link inter-chip di larghezza di banda e latenza limitate.

Parte 2: scalare i Transformer

È importante comprendere ogni parte dell'architettura Transformer: le dimensioni esatte di ogni matrice, dove si verifica la normalizzazione e quanti parametri e FLOP sono presenti in ogni parte. Questa parte esamina attentamente la matematica di Transformer, mostrando come contare i parametri e i FLOP sia per l'addestramento sia per l'inferenza. In questo modo saprai quanta memoria utilizzerà il modello, quanto tempo dedicherai al calcolo o alle comunicazioni e quando l'attenzione diventerà importante rispetto ai blocchi feed-forward.

Infine, questa parte ti aiuta a rispondere alla domanda fondamentale: dato un modello di una dimensione specifica e fornito un certo numero di chip, come parallelizzare il modello per rimanere nella condizione di scaling forte. Per rispondere a questa domanda, questa parte illustra le quattro tecniche di parallelismo principali utilizzate per dividere i modelli su più chip: dati, tensori, pipeline ed esperti. Descrive anche altre tecniche per ridurre i requisiti di memoria, come la rimaterializzazione, lo sharding dei modelli basato su ZeRO, l'offload dell'host e l'accumulo di gradienti.

Parte 3: implementazione pratica

Questa parte descrive come implementare i concetti di scaling utilizzando JAX e come profilare ed eseguire il debug del codice in caso di problemi.

  • Profilazione dei programmi TPU: i modelli LLM reali sono complessi e difficili da sviluppare, ottimizzare ed eseguire il debug. Questa sezione spiega lo stack JAX + XLA e come utilizzare il profiler JAX/TensorBoard per eseguire il debug e risolvere problemi reali.
  • Programmazione delle TPU in JAX: Questa sezione descrive come utilizzare le API JAX per parallelizzare il calcolo.