JAX-Arbeitslasten zu Pathways portieren

Aufgrund der verteilten Natur von JAX mit Pathways lassen sich einige Vorgänge aufgrund von Kommunikations-Overheads möglicherweise nicht gut skalieren. Pathways minimiert diesen Overhead mit Funktionen wie dem asynchronen Dispatch. Es gibt jedoch einige Dinge, die Sie beachten müssen, wenn Sie JAX-Arbeitslasten zu Pathways portieren oder eine JAX-Arbeitslast mit Pathways auf eine große Anzahl von Beschleunigern skalieren.

Hinweise

Sie benötigen Folgendes:

Prozessindex

Bei JAX mit Pathways werden alle Geräte in Ihrem Pathways-Cluster als lokal behandelt. Das vereinfacht die Geräteverwaltung und ermöglicht es JAX, alle verfügbaren Ressourcen zu nutzen. In der Praxis bedeutet das:

  • jax.process_index() ist für alle Geräte immer 0.
  • jax.devices() und jax.local_devices() geben alle TPU-Geräte für den gesamten Job zurück.

Hardwaretyp und Colocation

Für eine optimale Leistung sollten sich alle Pathways-Komponenten und der Nutzerjob in derselben Google Cloud Cloud-Zone befinden. Verwenden Sie eine große CPU wie den IFRT-Proxy und den Ressourcenmanager. Wir empfehlen mindestens eine dedizierte n2-standard-64 mit 64 vCPUs und 256 GB Arbeitsspeicher.

PathwaysUtils

Pathways-utils ist ein Python-basiertes GitHub-Repository mit wichtigen Dienstprogrammen und Tools, mit denen Sie die Bereitstellung und Ausführung von JAX-Arbeitslasten in der Pathways on Cloud-Architektur optimieren können. Dieses Paket übernimmt die erforderlichen Anpassungen für die Cloud-Umgebung, sodass sich JAX-Entwickler mit minimaler plattformspezifischer Konfiguration auf ihre wichtigsten Machine-Learning-Workflows konzentrieren können. Konkret bietet sie Folgendes:

  • Ein JAX-Backend vom Typ „Proxy“: Mit diesem benutzerdefinierten Backend kann Ihre JAX-Anwendung die Pathways-Infrastruktur nutzen, indem Sie die Umgebungsvariable JAX_PLATFORMS=proxy festlegen.
  • Integrierte Profiling-Tools: Profiling-Funktionen, mit denen Sie die Leistung Ihrer Anwendung nachvollziehen können. Mit Standard-JAX-Profiling-APIs wie jax.profiler.start_trace und jax.profiler.start_server können Sie nicht nur Ihren JAX-Code, sondern auch die zugrunde liegenden Pathways-Komponenten profilieren. So erhalten Sie einen ganzheitlichen Überblick über die Ausführung in der Cloud-Umgebung.
  • Verteilte Prüfpunkte mit Orbax: Ein benutzerdefinierter Orbax-Prüfpunkthandler, mit dem Sie verteilte Prüfpunkte verwenden und Ihre Prüfpunkte wiederherstellen können, wenn Sie die Orbax-Bibliothek in der Pathways-Umgebung verwenden. Diese Integration soll ohne Änderungen an Ihrem vorhandenen Orbax-Checkpointing-Code funktionieren, sofern pathwaysutils importiert wird.
  • Elastische Trainingsprimitive: Bietet grundlegende elastische Trainingsprimitive, mit denen Sie robuste und skalierbare Trainingsworkflows mit Pathways erstellen können. Mit diesen Primitiven können Ihre Trainingsjobs sich dynamisch an Änderungen bei den verfügbaren Ressourcen anpassen, was die Effizienz und Ausfallsicherheit in Cloud-Umgebungen verbessert.

Prüfpunkte

Orbax wird mit Pathways für verteiltes Erstellen und Wiederherstellen von Prüfpunkten mit Cloud Storage gründlich getestet. Wenn Sie import pathwaysutils; pathwaysutils.initialize() in Ihrem train.py aufrufen, wird ein benutzerdefinierter ArrayHandler registriert, der Prüfpunktvorgänge effizient über den IFRT-Proxy verarbeitet. Dadurch können Pathways-Worker auf Beschleunigern Daten direkt speichern und wiederherstellen.

Colocated Python

Colocated Python ist eine Open-Source-JAX-API, mit der Sie benutzerdefinierten Python-Code direkt auf den TPU- oder GPU-Hosts ausführen können. Das ist in JAX mit mehreren Controllern einfacher. So können rechenintensive Aufgaben wie das Laden von Daten und das Erstellen von Checkpoints vermieden werden, da keine Daten zwischen dem Client und den TPU-Maschinen übertragen werden müssen. Wenn Sie Ihren Pathways-Cluster so konfigurieren möchten, dass die Python-JAX-API am selben Ort ausgeführt wird, folgen Sie der Anleitung in der README-Datei für Python am selben Ort. Dort wird beschrieben, wie Sie einen Python-Sidecar am selben Ort wie Ihre Pathways-Worker starten.

Laden der Daten

Während des Trainings werden wiederholt Batches aus einem Dataset geladen, um sie in das Modell einzuspeisen. Ein effizienter, asynchroner Data Loader, der den Batch auf Hosts verteilt, ist wichtig, um unterausgelastete Beschleuniger zu vermeiden. Wenn Sie das Training mit Pathways ausführen, wird der Daten-Loader auf einer CPU-VM ausgeführt (im Gegensatz zu einer TPU-VM, die in Setups mit mehreren Controllern verwendet wird) und sendet Daten an TPU-VMs. Dies führt zu einer höheren Latenz beim Lesen von Daten. Diese wird jedoch teilweise dadurch gemindert, dass auf dem CPU-Host X Batches im Voraus gelesen und die gelesenen Daten asynchron an die TPUs gesendet werden. Diese Lösung ist ausreichend, wenn Sie sie in kleinem bis mittlerem Umfang einsetzen.

Für eine optimale Leistung im großen Maßstab empfehlen wir dringend, Ihre Eingabedatenpipeline mit colocated Python zu platzieren, um Ihre Datenpipeline direkt auf den Beschleunigern auszuführen. Dadurch wird der CPU-Engpass beseitigt und die schnellen Verbindungen der TPU für die Datenübertragung genutzt.

Eine Referenzimplementierung für die Migration einer TFDS-basierten Eingabepipeline finden Sie in der RemoteIterator-Implementierung in multihost_dataloading.py. Diese Implementierung funktioniert sowohl für JAX mit mehreren Controllern als auch für Pathways auf verteilte Weise mit der Python-JAX-API.

Jax-Versionsverwaltung

Pathways-Releases sind eng mit JAX-Versionen verknüpft, um Kompatibilität und Stabilität zu gewährleisten. Um potenzielle Probleme zu vermeiden, sollten Sie darauf achten, dass Ihre Pathways-Artefakte und Ihre JAX-Version aufeinander abgestimmt sind. In jedem Pathways-Release werden die kompatiblen JAX-Versionen durch ein Tag der Form jax-<version> angegeben.

Kompilierungs-Cache

Der persistente Kompilierungscache von Pathways ist eine Funktion, mit der Pathways-Server kompilierte XLA-Ausführungsdateien an einem persistenten Speicherort wie Cloud Storage speichern können, um redundante Kompilierungen zu vermeiden. Diese Funktion ist standardmäßig aktiviert. Der Speicherort des Cache wird als --gcs_scratch_location-Flag an die Ressourcenmanager- und Pathways-Worker-Container übergeben. Um die zugehörigen Speicherkosten so gering wie möglich zu halten, wird dem Cloud Storage-Speicherort eine Lebenszyklusrichtlinie zugewiesen. Pro Cloud Storage-Bucket gilt ein Limit von 50 Richtlinien. Daher empfehlen wir, für alle Arbeitslasten einen gemeinsamen Cloud Storage-Speicherort zu verwenden.

Dieser Cache ähnelt dem JAX-Kompilierungscache, der für Pathways-Arbeitslasten durch pathwaysutils.initialize() deaktiviert wird.

Profilerstellung

Mit dem JAX-Profiler können Sie Traces eines JAX-Programms generieren. Es gibt zwei gängige Methoden, die mit Pathways unterstützt werden:

  • Programmatisch
    • Profile programmatisch aus Ihrem JAX-Code erfassen
  • Manuell
    • Profile bei Bedarf erfassen, nachdem der Profiler-Server über Ihren JAX-Code gestartet wurde

In beiden Fällen werden die Profile in einen Cloud Storage-Bucket geschrieben. Im Cloud Storage-Bucket werden mehrere Tracedateien erstellt, möglicherweise in verschiedenen Zeitstempelordnern, z. B.:

  • Haupt-Python-Prozess, der den Trace aufgerufen hat (in der Regel Ihre Notebook-VM): <jax-client-vm-name>.xplane.pb
  • IFRT-Proxy für Pathways: client.<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb
  • Pathways Resource Manager: server.*<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb
  • Pathways-Mitarbeiter: server.*<tpu-node-name>.xplane.pb

Diese Tracedateien können mit TensorBoard analysiert werden. Führen Sie dazu den folgenden Befehl aus. Weitere Informationen zu TensorBoard und allen Profiler-Tools finden Sie unter TensorFlow-Leistung mit dem Profiler optimieren.

# verify trace files are present
gsutil ls -l -r gs://BUCKET/PREFIX

# View on tensorboard
tensorboard --logdir=gs://BUCKET/PREFIX

Ersetzen Sie Folgendes:

  • BUCKET : Ein Cloud Storage-Bucket zum Speichern der Tracedateien.
  • PREFIX: Ein Pfad in Ihrem Cloud Storage-Bucket zum Speichern der Tracedateien

Programmatische Profilerfassung

Erstellen Sie ein Profil aus Ihrem Code heraus. Die Profile werden in gs://<bucket>/<prefix> in einem Verzeichnis mit Zeitstempel gespeichert.

import jax
import pathwaysutils

pathwaysutils.initialize()

jax.profiler.start_trace("gs://BUCKET/PREFIX")

# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()

jax.profiler.stop_trace()

Manuelle Profilerfassung

Wenn Sie Profilinformationen manuell erfassen möchten, müssen Sie den Profiler-Server über Ihren Python-Code starten:

import jax
import pathwaysutils

pathwaysutils.initialize()
jax.profiler.start_server(jax_profiler_port)

# Your JAX code
jax.profiler.stop_server() # this is functinoally a no-op

Während der Profiler-Server ausgeführt wird, können Sie ein Profil erfassen und die Daten an den Zielspeicherort in Cloud Storage exportieren:

export DURATION_IN_SECS=6
export OUTPUT_DIR=gs://<var>BUCKET</var>/<var>PREFIX</var>

curl -d "{\"duration_ms\":\"${DURATION_IN_SECS} * 1000 }}\", \"repository_path\":\"${OUTPUT_DIR}\"}" -H "Content-Type: application/json" -X POST http://localhost:<jax_profiler_port

Zeitinformationen für IFRT-Proxy-Client-Methoden wie Compile und Execute finden Sie im Trace Ihres Programms. Diese Ereignisse, die die Interaktionen mit dem IFRT Proxy-gRPC-Server während der Kompilierung und Ausführung beschreiben, werden im Thread mit dem Namen GrpcClientSessionUserFuturesWorkQueue angezeigt. Wenn Sie sich diesen Thread in Ihrem Trace ansehen, erhalten Sie Einblicke in die Leistung dieser Vorgänge.

XLA-Flags

Wenn Sie Pathways verwenden, müssen Sie die XLA-Flags im Container „pathways-proxy“ festlegen. Sie können dazu XPK oder die PathwaysJob API verwenden.

Wenn Sie XPK verwenden, legen Sie XLA-Flags so fest:

--custom-pathways-proxy-server-args="--xla_flag_1=value1 --xla_flag_2=value2"

Wenn Sie die PathwaysJob API verwenden, legen Sie XLA-Flags so fest:

apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USER
spec:
  customComponents:
  - componentType: proxy_server
    customFlags:
    - --xla_flag_1=value1
    - --xla_flag_2=value2

Ersetzen Sie Folgendes:

  • USER : Ihr Google Cloud Nutzername
  • value[n]: die XLA-Flags, die Sie festlegen möchten

HLO-Dump

Wenn Sie sich die HLO-Eingaben (High Level Optimizer) ansehen möchten, die an den XLA-Compiler übergeben werden, können Sie Pathways so konfigurieren, dass die HLO-Eingaben an einem angegebenen Cloud Storage-Speicherort gespeichert werden:

apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USER
spec:
  customComponents:
  - componentType: proxy_server
    customEnv:
    - name: XLA_FLAGS
      value: "--xla_dump_to=gs://your-gcs-bucket/your-desired-prefix/"

Nächste Schritte