Leistungsoptimierungen für TPU7x (Ironwood)

In diesem Leitfaden werden verschiedene Methoden zur Optimierung der Leistung mit TPU7x (Ironwood) beschrieben, indem der Datenverkehr zwischen dem mehrstufigen Speichersystem effizient verwaltet wird. Dazu gehören Techniken wie Training mit niedriger Präzision, Sharding, Kommunikationsoptimierung, Rematerialisierung von Aktivierungen, Optimierung des virtuellen Speichers und benutzerdefinierte Accelerator-Kernel.

Um die Leistung mit TPU7x zu optimieren, müssen Sie zuerst mit der Ironwood-Architektur vertraut sein, insbesondere mit der Speicherhierarchie und der Interconnect-Topologie. Weitere Informationen finden Sie unter TPU7x (Ironwood).

Training mit niedriger Genauigkeit mit FP8

FP8 (8-Bit-Gleitkomma) ist ein effizientes numerisches Datenformat, das hauptsächlich zur Beschleunigung von Modelltraining und ‑inferenz verwendet wird. Durch die Darstellung von Zahlen mit 8 Bit – anstelle der Standardformate mit 16 Bit (FP16 oder BF16) und 32 Bit (FP32) – können TPUs Daten deutlich schneller verarbeiten und benötigen weniger Arbeitsspeicher.

TPU7x unterstützt die integrierte Hardwarebeschleunigung für FP8-Datentypen und bietet eine theoretische Spitzenleistung von 4614 TFLOPS pro Chip. Dies kann zu deutlich kürzeren End-to-End-Trainingszeiten führen. Bei kompatiblen Vorgängen, insbesondere bei dichten Matrixmultiplikationen, die für KI-Arbeitslasten üblich sind, kann die Verwendung von FP8 zu Leistungssteigerungen von 1,3-mal im Vergleich zum standardmäßigen BF16-Training führen. Im Vergleich zu BF16 verdoppelt FP8 die Spitzen-FLOPs und halbiert den Speicherbedarf für Gewichte und Aktivierungen. FP8 sollte ein primäres Tuning-Instrument für rechenintensive Arbeitslasten und Szenarien sein, die durch Speicherkapazität oder Bandbreite eingeschränkt sind.

Die Verwendung von FP8 bietet die folgenden Leistungsvorteile:

  • Geringerer Druck auf den Arbeitsspeicher mit hoher Bandbreite (HBM): Durch einen geringeren Speicherbedarf können größere Modelle oder Modelle mit größeren Schlüssel/Wert-Caches während der Inferenz vollständig in den 192 GB HBM passen. So wird ein kostspieliges Auslagern in den langsameren Hostspeicher vermieden.
  • Größere effektive Batchgröße: Durch die Reduzierung des für Aktivierungen erforderlichen Speichers ermöglicht FP8 die Verwendung größerer Batchgrößen. Dadurch wird die Datenparallelität verbessert, was zu einem höheren Durchsatz und einer besseren Auslastung der Recheneinheiten führen kann.
  • Geringere Anforderungen an die Speicherbandbreite: Wenn für jeden Vorgang nur die Hälfte der Daten verschoben wird, sinkt die Belastung des Datenpfads zwischen HBM und MXU. Auf Systemen, auf denen die Datenübertragung häufig einen Engpass darstellt, trägt dies dazu bei, dass die MXUs ausgelastet bleiben.

Die Verwendung von FP8 mit keiner oder nur geringer Leistungseinbuße erfordert die sorgfältige Auswahl von Quantisierungstechniken. Hier sind einige Best Practices für das FP8-Training:

  • Granularität der Skalierung: Beginnen Sie mit der Skalierung pro Tensor als Baseline. Wenn es Probleme mit der Qualität oder Leistung gibt, wechseln Sie zur achsenweisen Skalierung. Eine Skalierung des untergeordneten Channels ist möglicherweise nicht erforderlich.
  • Skalierungsmodus: Die dynamische Skalierung, bei der Skalierungsfaktoren zur Laufzeit berechnet werden, ist eine gute Standardeinstellung, um die Qualität beizubehalten. Die statische Skalierung kann die Leistung erheblich steigern, da Berechnungen entfallen. Allerdings ist eine sorgfältige Profilerstellung erforderlich, um die richtigen Skalierungsfaktoren zu ermitteln. Außerdem ist sie möglicherweise nicht für alle Anwendungsfälle geeignet, insbesondere wenn sich die Modellkonfigurationen ändern. Umgekehrt können einige robuste Modelle und Konfigurationen den Maßstab auf das FP8-Limit für Gewichte oder Aktivierungen festlegen. So können Sie den Quantisierungs-Overhead reduzieren und gleichzeitig die Genauigkeit beibehalten und die Leistung verbessern.
  • FP8-Formate (E4M3 und E5M2): Ein gängiger und effektiver Ansatz ist die Verwendung einer Mischung aus FP8-Formaten. Verwenden Sie beispielsweise E4M3 für Gewichte und Aktivierungen im Forward-Pass, um die höhere Genauigkeit von E4M3 zu nutzen, und E5M2 für Gradienten im Backward-Pass, um den größeren dynamischen Bereich der Gradienten zu berücksichtigen.
  • Runden: Wenn Sie für Gradienten „Auf die nächste gerade Zahl runden“ (Round to Nearest Even, RNE) anstelle von stochastischem Runden verwenden, kann die Qualität beibehalten werden und gleichzeitig eine bessere Leistung und Reproduzierbarkeit erzielt werden.
  • FP8 in MaxText aktivieren: MaxText unterstützt das FP8-Training über die QWIX-Quantisierungsbibliothek. Um die Quantisierung zu aktivieren, legen Sie das folgende Flag in Ihrer Konfiguration fest: use_qwix_quantization=true.

Fragmentierung und Parallelität

Beim Sharding wird ein großes Modell oder seine Trainingsdaten in kleinere Teile zerlegt und auf mehrere TPU-Chips oder ‑Kerne verteilt. Die Auswahl der richtigen Sharding-Strategie ist wichtig, um eine hohe Leistung auf TPU7x zu erzielen.

Ein naiver Ansatz, der nur den Grad der Parallelität maximiert, führt oft zu einer schlechten Leistung, da er kommunikationsgebunden wird. Am besten wählen Sie die einfachste Sharding-Strategie aus, die die Speicheranforderungen erfüllt, da so der Kommunikationsaufwand minimiert und die Recheneinheiten effizient genutzt werden können.

Bevor Sie eine Sharding-Strategie auswählen, sollten Sie als ersten Schritt bei der Leistungsoptimierung eine Analyse der arithmetischen Intensität durchführen. Bei dieser Analyse wird ermittelt, ob eine bestimmte Berechnung durch Rechenleistung, Speicherbandbreite oder Interconnect-Bandbreite begrenzt wird. Er wird als Verhältnis von Gleitkommaoperationen zu den Bytes an Daten berechnet, die verschoben werden müssen.

Eine hohe arithmetische Intensität deutet auf eine rechengebundene Arbeitslast hin. Eine niedrige arithmetische Intensität deutet auf eine speicher- oder kommunikationsgebundene Arbeitslast hin, bei der die Leistung durch die Geschwindigkeit begrenzt wird, mit der Daten aus dem HBM oder über das ICI-Netzwerk übertragen werden können. Diese Analyse liefert Informationen zur idealen Batch-Größe und Sharding-Strategie. Eine kommunikationsintensive Arbeitslast profitiert beispielsweise nicht von einer Sharding-Strategie, die noch mehr Kommunikation erfordert, z. B. Tensorparallelität mit hohem Grad.

Entscheidungsframework für die Sharding-Strategie

MaxText bietet eine Vielzahl von Sharding-Strategien. Die optimale Wahl hängt von der Modellarchitektur, der Sequenzlänge und der Notwendigkeit ab, die Rechenlast mit dem Kommunikationsaufwand in Einklang zu bringen.

  • Vollständig fragmentierte Datenparallelität (Fully Sharded Data Parallelism, FSDP): Dies ist die bevorzugte Standardstrategie für die Datenparallelität. Bei FSDP werden die Modellgewichte, Gradienten und Optimiererstatus auf die datenparallelen Geräte verteilt. Während der Berechnung führt jedes Gerät einen All-Gather-Vorgang aus, um die erforderlichen vollständigen Gewichte für seinen lokalen Microbatch abzurufen. FSDP ist sehr effektiv, solange die Batchgröße pro Gerät groß genug ist, um die Latenz dieser All-Gather-Kommunikation zu verbergen. Bei MoE-Modellen (Mixture-of-Experts) muss bei der Berechnung der arithmetischen Intensität die Sparsity berücksichtigt werden.
  • Tensor-Parallelität (TP): Bei TP werden einzelne Tensoren auf mehrere Geräte verteilt. In der Regel sind die Tensoren Gewichtsmatrizen in MLP- und Attention-Blöcken. Die hohe arithmetische Intensität der Hardware (11,5k) stellt sehr hohe Anforderungen an die Dimensionen des Modells, damit TP gegenüber ICI praktikabel ist.Der Versuch, TP zu verwenden, kann dazu führen, dass das System kommunikationsgebunden ist.
  • Expert Parallelism (EP): Dies ist die Standard- und notwendige Strategie für das Trainieren von MoE-Modellen. EP unterteilt die „Expert“-Ebenen auf eine Reihe von Geräten und ein All-to-All-Kommunikationskollektiv wird verwendet, um Tokens an das dafür vorgesehene Expert-Gerät weiterzuleiten. EP kann effizient sein, wenn die MLP-Dimension des Modells groß genug ist, um sich der Roofline anzunähern.
  • Kontextparallelität (Context Parallelism, CP): CP ist eine spezielle Strategie, die für das Training von Modellen mit sehr langen Sequenzlängen unerlässlich ist. Die primäre Funktion besteht darin, die Speichernutzung von Aktivierungen zu verwalten, die quadratisch mit der Sequenzlänge ansteigt und die HBM-Kapazität überschreiten kann. CP unterteilt die Sequenzdimension der Aktivierungstensoren in Shards, was die Verwendung einer fraktionierten Batchgröße pro Gerät ermöglicht. Da CP mehr Kommunikation als FSDP erfordert, gilt die allgemeine Regel, den Mindestgrad an CP zu verwenden, der erforderlich ist, um die Speicherbeschränkungen zu erfüllen und sicherzustellen, dass der Batchachsen-Shard eine Ganzzahl bleibt.

In der folgenden Tabelle werden gängige Arbeitslasttypen der optimalen Sharding-Strategie zugeordnet:

Arbeitslasttyp Empfohlenes primäres Sharding Sekundäre Fragmentierung Wichtige Engpässe Begründung
Kompaktes Modell – kurze Sequenz FSDP Rematerialisierung, FF-Matmuls FSDP bietet das beste Gleichgewicht. Bei kurzen Sequenzen ist der Aktivierungsspeicher möglicherweise kein großes Problem. Der Schlüssel ist ein ausreichend großer globaler Batch, um das All-Gather-Gewicht von FSDP zu verbergen. Mit zunehmender Batchgröße nimmt auch die Aktivierungsgröße zu. Eine geeignete Rematerialisierungsrichtlinie ist erforderlich, damit bei dieser Konfiguration nicht der Speicher ausgeht.
Dense-Modell – lange Sequenz FSDP CP Flash-Attention, Aktivierungsspeicher Der Aktivierungsspeicher wird zur primären Einschränkung. CP ist erforderlich, um fraktionierte Batchgrößen pro Gerät zu ermöglichen und Probleme mit fehlendem Arbeitsspeicher zu vermeiden. Flash Attention ist die dominierende Quelle für Rechenleistung und verschwendete Zeit.
MoE-Modell – kurze Sequenz FSDP + EP All-to-All (Expert Routing), Rematerialisierung Für MoE-Modelle ist EP erforderlich, um die Experten zu partitionieren. Die All-to-All-Kommunikation für das Token-Routing ist ein großer Engpass, der überlappt werden muss. Auch die Rematerialisierung ist eine erhebliche Quelle für Abfall.
MoE-Modell – sehr großer Maßstab FSDP + EP + PP Modellparallelität (MP) Alle zuvor erwähnten Engpässe sowie Pipeline-Bubbles Bei Modellen, die den Arbeitsspeicher eines einzelnen Pods überschreiten, ist PP erforderlich, um Ebenen auf mehrere Pods aufzuteilen. Dadurch werden DCN-Kommunikation und Pipeline-Bubble-Overheads eingeführt. Dies ist eine sehr komplexe Konfiguration, die sorgfältig abgestimmt werden muss.

Kommunikationsoptimierung

Der primäre Mechanismus für die Überlappung von Kommunikation und Berechnung auf TPU7x heißt SparseCore Collective Offloading. Die Ironwood-Architektur umfasst dedizierte SparseCore-Einheiten, die als unabhängige Steuerungs-Threads fungieren und in der Lage sind, Datenbewegungen über das ICI-Fabric zu verwalten. So können kollektive Kommunikationsvorgänge (z. B. „All-Gather“ oder „Reduce-Scatter“) parallel zu den Hauptberechnungen auf den Tensor-Cores ausgeführt werden. Dies ist die empfohlene Methode für asynchrone Kollektive auf TPU7x. Verwenden Sie die empfohlenen Flags, um das Offloading für die gängigsten Kollektive zu aktivieren.

Reaktivierung

Die Re-Materialisierung von Aktivierungen, auch als Gradienten-Checkpointing bezeichnet, ist eine grundlegende Methode zur Reduzierung des HBM-Speicherbedarfs eines Modells. Anstatt alle Zwischenaktivierungen aus dem Vorwärtsdurchlauf in HBM zu speichern, um sie während des Rückwärtsdurchlaufs zu verwenden, werden nur einige wichtige Aktivierungen (Checkpoints) gespeichert und die anderen werden bei Bedarf während des Rückwärtsdurchlaufs neu berechnet. Dadurch wird viel Arbeitsspeicher gespart, allerdings auf Kosten einer erhöhten Rechenleistung (ca. 25–30% zusätzliche FLOPs für einen Standard-Transformer-Block).

Die Entscheidung, wie aggressiv die Rematerialisierung angewendet werden soll, ist ein wichtiger Abstimmungsparameter, der vollständig vom primären Engpass abhängt, der oft mit der Sequenzlänge variiert.

Bei Arbeitslasten mit langen Sequenzen (z. B. 128.000): In diesen Fällen ist die Größe der Aktivierungstensoren der Hauptverbraucher von HBM. Die Arbeitslast ist in der Regel speichergebunden. Daher ist es sehr sinnvoll, eine aggressive Rematerialisierungsrichtlinie anzuwenden. Durch die Arbeitsspeichereinsparungen kann das Training ohne Fehler aufgrund mangelnden Arbeitsspeichers fortgesetzt werden. Außerdem sind größere Batchgrößen möglich. Der Rechenaufwand für die Neuberechnung ist ein lohnender Kompromiss.

Bei Arbeitslasten mit kurzen Sequenzen (z. B. 8k): In diesen Fällen ist der Aktivierungsspeicher viel weniger wichtig und die Arbeitslast ist eher rechengebunden. Der Rechenaufwand für die Rematerialisierung kann die größte Quelle für Ineffizienz sein.

Richtlinien für die Rematerialisierung in MaxText optimieren

MaxText bietet mit einer Reihe von voreingestellten und benutzerdefinierten Richtlinien, die mit dem Flag remat_policy konfiguriert werden, eine detaillierte Steuerung der Rematerialisierung.

Voreingestellte Richtlinien

MaxText bietet die folgenden integrierten Richtlinien:

  • full: Die aggressivste Richtlinie, bei der fast alles wiederhergestellt wird. Dadurch wird die HBM-Nutzung minimiert, aber der Aufwand für die Neuberechnung maximiert. Ideal für Szenarien mit extrem wenig Arbeitsspeicher und langen Sequenzen.
  • minimal: Die am wenigsten aggressive Richtlinie, bei der die meisten Aktivierungen gespeichert werden. Dadurch wird die HBM-Nutzung maximiert, aber die Neuberechnung minimiert. Am besten geeignet für kurze Sequenzen, rechenintensive Arbeitslasten, bei denen der Arbeitsspeicher keine Rolle spielt.
  • Zwischenrichtlinien: Optionen wie save_dot_with_context_except_mlp, save_qkv_proj und save_out_proj bieten verschiedene Kompromisse, indem die Ausgaben von rechenintensiven Dot-Product-Operationen selektiv geprüft werden, während kostengünstigere elementweise Operationen rematerialisiert werden.

Benutzerdefinierte Richtlinien

Für mehr Kontrolle können Sie remat_policy auf custom setzen. So können Sie das Verhalten für einzelne Ebenen im Decodierungsmodul des Modells festlegen. Jeder Ebene kann eines von drei Verhaltensmustern zugewiesen werden:

  • device: Die Aktivierung wird im HBM auf dem TPU-Gerät gespeichert.
  • remat: Die Aktivierung wird verworfen und während des Rückwärtsdurchlaufs neu materialisiert.
  • offload: Die Aktivierung wird vom HBM in den Arbeitsspeicher des CPU-Hosts verschoben. Dadurch wird HBM freigegeben, allerdings auf Kosten der PCIe-Übertragungslatenz.

Abgestimmte VMEM-Abstimmung

Die Kernelleistung, z. B. Flash Attention, hängt von den ausgewählten Kachelgrößen im Kernel ab. Die Größe ist durch den verfügbaren Vektorspeicher (VMEM) begrenzt. TPU7x-Chips haben 64 MB VMEM, die zwischen dem aktuellen Bereich (Scoped VMEM) und dem zukünftigen Gewichtsvorabruf aufgeteilt werden können. Durch Erhöhen des bereichsbezogenen VMEM können die Kachelgrößen im Kernel erhöht werden, wodurch möglicherweise Speicherblockierungen reduziert und die Leistung von Kernels gesteigert wird. Sie können die Größe des VMEM mit beschränktem Umfang ändern, indem Sie xla_tpu_scoped_vmem_limit_kib (in LIBTPU_INIT_ARGS) festlegen. So können Sie die Kernel- und End-to-End-Leistungsgrenzen untersuchen. Die Optimierung der Größe des VMEM-Bereichs kann sich indirekt auf die Leistung benutzerdefinierter Pallas-Kernel auswirken, da durch die Vergrößerung des VMEM-Bereichs ein größerer Hyperparameter-Suchbereich für die Kachelgrößen im Kernel verfügbar wird.

Tokamax-Kernel

Tokamax, eine leistungsstarke JAX-Kernel-Bibliothek mit vielen hochoptimierten TPU-Kernels, behebt mehrere häufige hardwarespezifische Engpässe:

  • Splash-Attention: Splash-Attention wird als primäre Attention-Implementierung verwendet, um den HBM-Engpass von Standard-Attention zu beseitigen. Außerdem wird die effizienteste Attention-Implementierung auf TPUs verwendet.
  • Gruppierte Matrixmultiplikation (Grouped Matrix Multiplication, GMM) mit Megablox: Bei MoE-Arbeitslasten verarbeitet Megablox gruppierte Matrixmultiplikationen effizient, indem Berechnungen für die Darstellung mit unregelmäßigen Aktivierungen durchgeführt werden. Sie wird effizient auf die unregelmäßige Dimension abgebildet und berechnet Matrixmultiplikationen zwischen unregelmäßigen Gruppen von Zeilen in LHS und der entsprechenden Expertenmatrix. So ist es nicht erforderlich, Batches auf eine feste Größe aufzufüllen.
  • Empirische Optimierung mit tune-jax: Die tune-jax-Bibliothek enthält Dienstprogramme für die empirische Suche nach optimalen Blockgrößen. Die Standardkernelgrößen sind oft suboptimal. Durch die Optimierung können Sie hardwarefreundliche VMEM-Kachelgrößen auswählen, um die Hardwareauslastung zu maximieren.
  • Maximale Logits-Schätzung: Der Tokamax Splash-Attention-Kernel kann weiter optimiert werden, indem Sie einen Wert für max_logit_const festlegen. Falls festgelegt, ersetzt sie die Berechnung der Reduzierung des maximalen Logits während der Softmax-Operation der Aufmerksamkeit (softmax(Q * KT)), wodurch der Rechen- und Synchronisierungsaufwand reduziert wird. In MaxText wird sie durch die Konfiguration use_max_logits_estimate implementiert, die auf None (deaktiviert) oder einen Gleitkommawert festgelegt werden kann. Prüfen Sie, ob der Logit-Bereich Ihres spezifischen Modells mit der Schätzung kompatibel ist, um einen numerischen Überlauf zu vermeiden. Wenn dieser Wert festgelegt ist, wird ein Konvergenztest empfohlen.