Aufgrund der verteilten Natur von JAX mit Pathways lassen sich einige Vorgänge aufgrund von Kommunikations-Overheads möglicherweise nicht gut skalieren. Pathways minimiert diese Overheads zwar mit Funktionen wie dem asynchronen Versand, es gibt jedoch einige Dinge, die Sie beachten müssen, wenn Sie JAX-Arbeitslasten zu Pathways portieren oder eine JAX-Arbeitslast mit Pathways auf eine große Anzahl von Beschleunigern skalieren.
Hinweis
Sie benötigen Folgendes:
- Kubernetes-Tools installiert
- gcloud CLI installiert
- TPU API aktiviert
- Google Kubernetes Engine API aktiviert
Prozessindex
Bei JAX mit Pathways werden alle Geräte in Ihrem Pathways-Cluster als lokal behandelt. Dadurch wird die Geräteverwaltung vereinfacht und JAX kann alle verfügbaren Ressourcen nutzen. In der Praxis bedeutet das:
jax.process_index()ist für alle Geräte immer 0.jax.devices()undjax.local_devices()geben alle TPU-Geräte für den gesamten Job zurück.
Hardwaretyp und Colocation
Für eine optimale Leistung sollten sich alle Pathways-Komponenten und der Nutzerjob in derselben
Cloud-Zone Google Cloud befinden. Verwenden Sie eine große CPU wie den IFRT-Proxy und den Ressourcenmanager. Wir empfehlen mindestens eine dedizierte n2-standard-64 mit 64 vCPUs und 256 GB Arbeitsspeicher.
PathwaysUtils
Pathways-utils ist ein Python-basiertes GitHub-Repository mit wichtigen Dienstprogrammen und Tools, mit denen Sie die Bereitstellung und Ausführung von JAX-Arbeitslasten in der Pathways on Cloud-Architektur optimieren können. Dieses Paket übernimmt die erforderlichen Anpassungen für die Cloud-Umgebung, sodass sich JAX-Entwickler mit minimaler plattformspezifischer Konfiguration auf ihre wichtigsten Machine-Learning-Arbeitsabläufe konzentrieren können. Insbesondere bietet es Folgendes:
- Ein „Proxy“-JAX-Backend: Mit diesem benutzerdefinierten Backend kann Ihre JAX-Anwendung die Pathways-Infrastruktur nutzen, indem Sie die Umgebungsvariable
JAX_PLATFORMS=proxyfestlegen. - Integrierte Profilerstellungsdienstprogramme: Profilerstellungsfunktionen, mit denen Sie die Leistung Ihrer Anwendung analysieren können. Mit Standard-JAX-Profilerstellungs-APIs wie
jax.profiler.start_traceundjax.profiler.start_serverkönnen Sie nicht nur Ihren JAX-Code, sondern auch die zugrunde liegenden Pathways-Komponenten profilieren und so einen ganzheitlichen Überblick über die Ausführung in der Cloud-Umgebung erhalten. - Verteilte Prüfpunkte mit Orbax: Ein benutzerdefinierter Orbax-Prüfpunkthandler, mit dem Sie verteilte Prüfpunkte verwenden und Ihre Prüfpunkte wiederherstellen können, wenn Sie die Orbax-Bibliothek in der Pathways-Umgebung verwenden. Diese Integration soll ohne Änderungen an Ihrem vorhandenen Orbax-Prüfpunktcode funktionieren, sofern
pathwaysutilsimportiert wird. - Elastische Trainingsprimitive: Bietet grundlegende elastische Trainingsprimitive, mit denen Sie robuste und skalierbare Trainingsarbeitsabläufe mit Pathways erstellen können. Mit diesen Primitiven können Ihre Trainingsjobs sich dynamisch an Änderungen der verfügbaren Ressourcen anpassen, wodurch die Effizienz und Stabilität in Cloud-Umgebungen verbessert werden.
Prüfpunkte
Orbax wird mit Pathways umfassend für
verteilte Prüfpunkte und die Wiederherstellung mit Cloud Storage getestet. Wenn Sie
die Funktion import pathwaysutils; pathwaysutils.initialize() in Ihrer train.py aufrufen, wird ein benutzerdefinierter
ArrayHandler registriert, der Prüfpunktvorgänge effizient
über den IFRT
Proxy verarbeitet. So können Pathways-Worker auf Beschleunigern Daten direkt speichern und wiederherstellen.
Colocated Python
Colocated Python ist eine Open-Source-JAX-API, mit der Sie nutzerdefinierten Python-Code direkt auf den TPU- oder GPU-Hosts ausführen können. Das ist in JAX mit mehreren Controllern einfacher. So können rechenintensive Aufgaben wie das Laden von Daten und Prüfpunkte ohne Datenübertragung zwischen dem Client und den TPU-Maschinen ausgeführt werden. Wenn Sie Ihren Pathways-Cluster so konfigurieren möchten, dass die colocated Python JAX API ausgeführt wird, folgen Sie der Anleitung in der README-Datei für colocated Python. Dort wird erklärt, wie Sie einen colocated Python-Sidecar neben Ihren Pathways-Workern starten.
Laden der Daten
Während des Trainings werden wiederholt Batches aus einem Dataset geladen, um sie in das Modell einzuspeisen. Ein effizienter, asynchroner Data Loader, der den Batch auf Hosts verteilt, ist wichtig, um unterausgelastete Beschleuniger zu vermeiden. Wenn Sie das Training mit Pathways ausführen, wird der Data Loader auf einer CPU-VM ausgeführt (im Gegensatz zu einer TPU-VM, die bei Setups mit mehreren Controllern verwendet wird) und sendet Daten an TPU-VMs. Dadurch erhöht sich die Latenz beim Lesen von Daten, was jedoch teilweise dadurch gemildert wird, dass eine bestimmte Anzahl von Batches im Voraus auf dem CPU-Host gelesen und die gelesenen Daten asynchron an die TPUs gesendet werden. Diese Lösung ist ausreichend, wenn Sie in kleinem bis mittlerem Maßstab arbeiten.
Für eine optimale Leistung im großen Maßstab empfehlen wir dringend, Ihre Eingabe datenpipeline gemeinsam zu platzieren, indem Sie colocated Python verwenden, um Ihre Datenpipeline direkt auf den Beschleunigern auszuführen. Dadurch wird der CPU-Engpass beseitigt und die schnellen TPU-Verbindungen für die Datenübertragung genutzt.
Eine Referenzimplementierung der Migration einer TFDS-basierten
Eingabepipeline finden Sie in der RemoteIterator Implementierung in
multihost_dataloading.py.
Diese Implementierung funktioniert sowohl in JAX mit mehreren Controllern als auch in Pathways auf verteilte Weise mit der colocated Python JAX API.
JAX-Versionsverwaltung
Pathways-Releases sind eng mit JAX-Versionen verknüpft, um Kompatibilität und Stabilität zu gewährleisten. Um potenzielle Probleme zu vermeiden, prüfen Sie, ob Ihre Pathways-Artefakte und Ihre JAX-Version übereinstimmen. In jedem Pathways-Release werden die
kompatiblen JAX-Versionen durch ein Tag im Format jax-<version> angegeben.
Kompilierungscache
Der persistente Kompilierungscache von Pathways ist eine Funktion, mit der Pathways-Server kompilierte XLA-Ausführungsdateien an einem persistenten Speicherort wie Cloud Storage speichern können, um redundante Kompilierungen zu vermeiden. Diese Funktion ist standardmäßig aktiviert. Der Speicherort des Cache wird als Flag --gcs_scratch_location an den Ressourcenmanager und die Pathways-Worker-Container übergeben. Um die zugehörigen Speicherkosten zu minimieren, wird dem Cache eine Lebenszyklusrichtlinie für den Cloud Storage-Speicherort angehängt. Es gibt ein Limit von 50 Richtlinien pro Cloud Storage-Bucket. Daher empfehlen wir, einen gemeinsamen Cloud Storage-Speicherort für alle Arbeitslasten zu verwenden.
Dieser Cache ähnelt dem JAX-Kompilierungscache
der von pathwaysutils.initialize() für Pathways-Arbeitslasten deaktiviert wird.
Für den Kompilierungscache sind die folgenden Cloud Storage-Berechtigungen erforderlich:
storage.buckets.get: Zum Abrufen von Bucket-Metadaten.storage.buckets.update: Erforderlich, damit Pathways Lebenszyklusrichtlinien für Objekte einrichten kann, um die TTL für die Cache-Entfernung zu erzwingen.storage.objects.list: Zum Auflisten vorhandener Cache-Objekte im Bucket.storage.objects.create: Zum Schreiben neuer kompilierter Ausführungsdateien in den Cache.storage.objects.get: Zum Lesen von im Cache gespeicherten Ausführungsdateien aus dem Bucket.
Profilerstellung
Mit dem JAX-Profiler können Sie Traces eines JAX-Programms generieren. Es gibt zwei gängige Methoden, die mit Pathways unterstützt werden:
- Programmatisch
- Profile programmatisch aus Ihrem JAX-Code erfassen
- Manuell
- Profile bei Bedarf erfassen, nachdem der Profiler-Server aus Ihrem JAX-Code gestartet wurde
In beiden Fällen werden die Profile in einen Cloud Storage-Bucket geschrieben. Im Cloud Storage-Bucket werden möglicherweise mehrere Trace-Dateien in verschiedenen Ordnern mit Zeitstempeln erstellt, z. B.:
- Haupt-Python-Prozess, der den Trace aufgerufen hat (in der Regel Ihre Notebook-VM):
<jax-client-vm-name>.xplane.pb - Pathways-IFRT-Proxy:
client.<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb - Pathways-Ressourcenmanager:
server.*<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb - Pathways-Worker:
server.*<tpu-node-name>.xplane.pb
Diese Trace-Dateien können mit TensorBoard analysiert werden. Führen Sie dazu den folgenden Befehl aus. Weitere Informationen zu TensorBoard und allen Profilerstellungstools finden Sie unter TensorFlow-Leistung mit dem Profiler optimieren.
# verify trace files are present gsutil ls -l -r gs://BUCKET/PREFIX # View on tensorboard tensorboard --logdir=gs://BUCKET/PREFIX
Ersetzen Sie Folgendes:
BUCKET: Ein Cloud Storage-Bucket zum Speichern der Trace-DateienPREFIX: Ein Pfad in Ihrem Cloud Storage-Bucket zum Speichern der Trace-Dateien
Programmatische Profilerstellung
Erfassen Sie ein Profil aus Ihrem Code. Die Profile werden in
gs://<bucket>/<prefix> in einem Ordner mit Zeitstempel gespeichert.
import jax import pathwaysutils pathwaysutils.initialize() jax.profiler.start_trace("gs://BUCKET/PREFIX") # Run the operations to be profiled key = jax.random.key(0) x = jax.random.normal(key, (5000, 5000)) y = x @ x y.block_until_ready() jax.profiler.stop_trace()
Manuelle Profilerstellung
Wenn Sie Profilinformationen manuell erfassen möchten, müssen Sie den Profiler-Server aus Ihrem Python-Code starten:
import jax
import pathwaysutils
pathwaysutils.initialize()
jax.profiler.start_server(jax_profiler_port)
# Your JAX code
jax.profiler.stop_server() # this is functinoally a no-op
Während der Profiler-Server ausgeführt wird, können Sie ein Profil erfassen und die Daten an den Zielspeicherort in Cloud Storage exportieren:
export DURATION_IN_SECS=6
export OUTPUT_DIR=gs://<var>BUCKET</var>/<var>PREFIX</var>
curl -d "{\"duration_ms\":\"${DURATION_IN_SECS} * 1000 }}\", \"repository_path\":\"${OUTPUT_DIR}\"}" -H "Content-Type: application/json" -X POST http://localhost:<jax_profiler_port
Informationen zur Zeitmessung für IFRT-Proxy-Clientmethoden wie Compile und Execute finden Sie im Trace Ihres Programms. Diese Ereignisse, die die Interaktionen mit dem IFRT-Proxy-gRPC-Server während der Kompilierung und Ausführung beschreiben, werden im Thread GrpcClientSessionUserFuturesWorkQueue angezeigt. Wenn Sie diesen Thread in Ihrem Trace untersuchen, können Sie Einblicke in die Leistung dieser Vorgänge erhalten.
XLA-Flags
Wenn Sie Pathways verwenden, müssen Sie die XLA-Flags im Pathways-Proxy-Container festlegen. Sie können dies mit XPK oder der PathwaysJob API tun.
Wenn Sie XPK verwenden, legen Sie XLA-Flags wie folgt fest:
--custom-pathways-proxy-server-args="--xla_flag_1=value1 --xla_flag_2=value2"
Wenn Sie die PathwaysJob API verwenden, legen Sie XLA-Flags wie folgt fest:
apiVersion: pathways-job.pathways.domain/v1 kind: PathwaysJob metadata: name: pathways-USER spec: customComponents: - componentType: proxy_server customFlags: - --xla_flag_1=value1 - --xla_flag_2=value2
Ersetzen Sie Folgendes:
USER: Ihr Google Cloud Nutzernamevalue[n]: Die XLA-Flags, die Sie festlegen möchten
HLO-Dump
Wenn Sie sich genauer mit den HLO-Eingaben (High Level Optimizer) befassen möchten, die an den XLA-Compiler übergeben werden, können Sie Pathways so konfigurieren, dass der HLO wie folgt an einen bestimmten Cloud Storage-Speicherort gesendet wird:
apiVersion: pathways-job.pathways.domain/v1 kind: PathwaysJob metadata: name: pathways-USER spec: customComponents: - componentType: proxy_server customEnv: - name: XLA_FLAGS value: "--xla_dump_to=gs://your-bucket/your-desired-prefix/"
Nächste Schritte
- GKE-Cluster mit Pathways erstellen
- Inferenz auf mehreren Hosts mit Pathways
- Batch-Arbeitslasten mit Pathways
- Interaktiver Modus von Pathways
- Stabiles Training mit Pathways
- Fehlerbehebung bei Pathways