Modell auf TPUs skalieren
In diesem Dokument wird beschrieben, wie Sie Sprachmodelle skalieren können: wie TPUs funktionieren und wie sie miteinander kommunizieren, wie LLMs auf echter Hardware ausgeführt werden und wie Sie Ihre Modelle während des Trainings und der Inferenz parallelisieren, damit sie effizient in großem Maßstab ausgeführt werden können. Wir stellen Informationen zur Verfügung, mit denen Sie einschätzen können, wie teuer das Trainieren eines LLM ist, wie viel Speicher Sie für die Bereitstellung des Modells benötigen und wie Sie Modelle effektiv auf mehrere TPUs aufteilen können.
Viele Aspekte von Deep Learning sind komplex, aber die Optimierung der Leistung Ihrer Modelle muss es nicht sein, auch nicht im großen Maßstab. Die grundlegenden Prinzipien gelten überall – von einem einzelnen Beschleuniger bis hin zu Zehntausenden. Wenn Sie sie verstehen, können Sie viele nützliche Dinge tun:
- Schätzen Sie, wie nah die einzelnen Teile Ihres Modells an ihrem theoretischen Optimum liegen.
- Fundierte Entscheidungen zu verschiedenen Parallelisierungsschemas in unterschiedlichen Größenordnungen treffen (wie Sie die Berechnung auf mehrere Geräte aufteilen).
- Schätzen Sie die Kosten und den Zeitaufwand, die für das Training und die Ausführung großer Transformer-Modelle erforderlich sind.
- Entwerfen Sie Algorithmen, die die TPU-Architektur nutzen.
- Entwerfen Sie Modellarchitekturen, die auf einem expliziten Verständnis der Faktoren beruhen, die die Algorithmusleistung einschränken.
Vorbereitung
Sie sollten ein grundlegendes Verständnis von LLMs und der Transformer-Architektur haben, aber nicht unbedingt wissen, wie sie in großem Maßstab funktionieren. Sie sollten die Grundlagen des LLM-Trainings kennen und idealerweise über grundlegende Kenntnisse in JAX verfügen. Nützliche Hintergrundinformationen zur Transformer-Architektur finden Sie hier:
- The Illustrated Transformer: Blogpost zur Transformer-Architektur
- Attention Is All You Need: das ursprüngliche Transformer-Paper
Nachdem Sie sich mit diesen Voraussetzungen vertraut gemacht haben, sollten Sie in der Lage sein, das beste Parallelitätsschema für ein Transformer-Modell auf einer bestimmten TPU-Plattform zu schätzen. Außerdem können Sie schätzen, wie lange Training und Inferenz dauern sollten.
Bedeutung der Modellskalierung
LLMs und die meisten kleinen Modelle laufen heute so nah an den Hardwaregrenzen, dass Sie bei der Entwicklung von Modellen an Effizienz im großen Maßstab denken müssen. Ein Benchmark-Vorteil von 20% ist irrelevant, wenn er mit einem Verlust von 20% bei der Roofline-Effizienz einhergeht. Vielversprechende Modellarchitekturen scheitern regelmäßig, weil sie nicht effizient skaliert werden können oder weil es an Optimierungsbemühungen mangelt, um dies zu ermöglichen.
Ziel der Modellskalierung ist es, die Anzahl der für das Training oder die Inferenz verwendeten Chips zu erhöhen und gleichzeitig einen proportionalen, linearen Anstieg des Durchsatzes zu erzielen. Dies wird als „Strong Scaling“ bezeichnet. Durch das Hinzufügen zusätzlicher Chips (Parallelität) wird die Berechnungszeit in der Regel verkürzt, es kommt aber auch zu einer zusätzlichen Kommunikation zwischen den Chips. Wenn die Kommunikation länger dauert als die Berechnung, wird das Modell durch die Kommunikation eingeschränkt und lässt sich nicht gut skalieren. Wenn Sie die Hardware gut genug kennen, um vorherzusehen, wo diese Engpässe auftreten, können Sie Ihre Modelle so gestalten oder neu konfigurieren, dass sie vermieden werden.
In den folgenden Abschnitten erhalten Sie einen Überblick darüber, wie TPU-Hardware skaliert wird und wie sich die Transformer-Architektur entwickelt hat. Diese Informationen sind sowohl für Forscher, die neue Architekturen entwerfen, als auch für Ingenieure nützlich, die daran arbeiten, die aktuelle Generation von LLMs schnell auszuführen.
Teil 1: Konzepte
In diesem Teil werden die Roofline-Analyse und die Faktoren erläutert, die die Skalierbarkeit eines Modells einschränken (Kommunikation, Berechnung und Arbeitsspeicher). Als Nächstes beschreiben wir, wie TPUs funktionieren, sowohl als einzelne Chips als auch – was von entscheidender Bedeutung ist – als vernetztes System mit Inter-Chip-Verbindungen mit begrenzter Bandbreite und Latenz.
- Einführung in die Roofline-Analyse: In diesem Abschnitt wird beschrieben, wie Sie anhand von Rechen-, Kommunikations- und Speicherlimits abschätzen können, wie schnell Ihr Algorithmus ausgeführt wird.
- Vorgänge in der TPU-Architektur: In diesem Abschnitt wird die Architektur von TPUs beschrieben. Außerdem wird erläutert, wie verschiedene Hardwaremodule in TPUs funktionieren und wie sie sich auf das Modelltraining und die Bereitstellung auswirken.
- Modellfragmentierung für die Parallelität mehrerer TPUs: In diesem Abschnitt werden die Modellfragmentierung und die Parallelität mehrerer TPUs anhand von fragmentierten Matrixmultiplikationen erläutert.
Teil 2: Transformers skalieren
Es ist wichtig, jeden Teil der Transformer-Architektur zu verstehen: die genauen Größen jeder Matrix, wo die Normalisierung erfolgt und wie viele Parameter und FLOPS in jedem Teil enthalten sind. In diesem Teil wird die Transformer-Mathematik sorgfältig durchgegangen und es wird gezeigt, wie die Parameter und FLOPs für Training und Inferenz gezählt werden. So erfahren Sie, wie viel Arbeitsspeicher Ihr Modell benötigt, wie viel Zeit Sie für Berechnungen oder Kommunikation aufwenden und wann die Aufmerksamkeit im Verhältnis zu den Feedforward-Blöcken wichtig wird.
Schließlich hilft Ihnen dieser Teil, die grundlegende Frage zu beantworten: Wie kann ein Modell einer bestimmten Größe, das mit einer bestimmten Anzahl von Chips bereitgestellt wird, parallelisiert werden, um die Bedingung für die starke Skalierung zu erfüllen? Um diese Frage zu beantworten, werden in diesem Teil die vier wichtigsten Parallelisierungstechniken behandelt, die zum Aufteilen von Modellen auf mehrere Chips verwendet werden: Daten-, Tensor-, Pipeline- und Expertenparallelisierung. Außerdem werden andere Techniken zur Reduzierung des Speicherbedarfs beschrieben, z. B. Rematerialisierung, ZeRO-basiertes Modell-Sharding, Host-Offload und Gradientenakkumulierung.
- Einführung in die mathematischen Operationen von Transformer-Modellen: In diesem Abschnitt wird die Mathematik durchgegangen, um Fragen zur Anzahl der FLOPs zu beantworten, die von einem Transformer während Vorwärts- und Rückwärtsdurchläufen verwendet werden, Berechnungen zur Berechnung der Anzahl der Parameter und der Größe von KV-Caches.
- Transformer-Parallelisierung für das Training: In diesem Abschnitt wird beschrieben, wie Sie die Trainingseffizienz maximieren, indem Sie FSDP, Megatron-Sharding und Pipeline-Parallelität koordinieren. Darin wird beschrieben, wie Sie die optimale Verteilung für eine bestimmte Modellgröße und Batchgröße auf eine feste Anzahl von Chips bestimmen, um den maximalen Durchsatz zu erzielen.
- Llama 3 auf TPUs trainieren: In diesem Unterabschnitt wird beschrieben, wie Sie Llama 3 auf TPUs trainieren, wie lange das dauern kann und wie viel es kosten kann.
- Transformer-Skalierung für die Inferenz: Nachdem ein Modell trainiert wurde, muss es bereitgestellt werden. Die Inferenz bringt eine neue Überlegung mit sich, die Latenz, und verändert die Speichersituation. In diesem Abschnitt wird beschrieben, wie die disaggregierte Bereitstellung funktioniert und wie Sie KV-Caches verwenden können.
- Llama 3 auf TPUs bereitstellen: In diesem Unterabschnitt wird beschrieben, wie Sie Llama 3 auf TPUs bereitstellen, wie viel das kosten könnte und welche Kompromisse bei Latenz und Durchsatz bestehen.
Teil 3: Praktische Umsetzung
In diesem Teil wird beschrieben, wie Sie die Skalierungskonzepte mit JAX implementieren und wie Sie Ihr Programm profilieren und debuggen, wenn etwas schiefgeht.
- TPU-Programme profilieren: Echte LLMs sind komplex und lassen sich nur schwer entwickeln, optimieren und debuggen. In diesem Abschnitt wird der JAX- und XLA-Stack erläutert und es wird beschrieben, wie Sie den JAX-/TensorBoard-Profiler verwenden, um echte Probleme zu beheben.
- TPUs in JAX programmieren: In diesem Abschnitt wird beschrieben, wie Sie die JAX-APIs verwenden, um Berechnungen zu parallelisieren.