Aufgrund der verteilten Natur von JAX mit Pathways lassen sich einige Vorgänge aufgrund von Kommunikations-Overheads möglicherweise nicht gut skalieren. Pathways minimiert diesen Overhead mit Funktionen wie dem asynchronen Dispatch. Es gibt jedoch einige Dinge, die Sie beachten müssen, wenn Sie JAX-Arbeitslasten zu Pathways portieren oder eine JAX-Arbeitslast mit Pathways auf eine große Anzahl von Beschleunigern skalieren.
Hinweise
Sie benötigen Folgendes:
- Installierte Kubernetes-Tools
- gcloud CLI installiert
- TPU API aktiviert
- Google Kubernetes Engine API aktiviert
Prozessindex
Bei JAX mit Pathways werden alle Geräte in Ihrem Pathways-Cluster als lokal behandelt. Das vereinfacht die Geräteverwaltung und ermöglicht es JAX, alle verfügbaren Ressourcen zu nutzen. In der Praxis bedeutet das:
jax.process_index()ist für alle Geräte immer 0.jax.devices()undjax.local_devices()geben alle TPU-Geräte für den gesamten Job zurück.
Hardwaretyp und Colocation
Für eine optimale Leistung sollten sich alle Pathways-Komponenten und der Nutzerjob in derselben Google Cloud Cloud-Zone befinden. Verwenden Sie eine große CPU wie den IFRT-Proxy und den Ressourcenmanager. Wir empfehlen mindestens eine dedizierte n2-standard-64 mit 64 vCPUs und 256 GB Arbeitsspeicher.
PathwaysUtils
Pathways-utils ist ein Python-basiertes GitHub-Repository mit wichtigen Dienstprogrammen und Tools, mit denen Sie die Bereitstellung und Ausführung von JAX-Arbeitslasten in der Pathways on Cloud-Architektur optimieren können. Dieses Paket übernimmt die erforderlichen Anpassungen für die Cloud-Umgebung, sodass sich JAX-Entwickler mit minimaler plattformspezifischer Konfiguration auf ihre wichtigsten Machine-Learning-Workflows konzentrieren können. Konkret bietet sie Folgendes:
- Ein JAX-Backend vom Typ „Proxy“: Mit diesem benutzerdefinierten Backend kann Ihre JAX-Anwendung die Pathways-Infrastruktur nutzen, indem Sie die Umgebungsvariable
JAX_PLATFORMS=proxyfestlegen. - Integrierte Profiling-Tools: Profiling-Funktionen, mit denen Sie die Leistung Ihrer Anwendung nachvollziehen können. Mit Standard-JAX-Profiling-APIs wie
jax.profiler.start_traceundjax.profiler.start_serverkönnen Sie nicht nur Ihren JAX-Code, sondern auch die zugrunde liegenden Pathways-Komponenten profilieren. So erhalten Sie einen ganzheitlichen Überblick über die Ausführung in der Cloud-Umgebung. - Verteilte Prüfpunkte mit Orbax: Ein benutzerdefinierter Orbax-Prüfpunkthandler, mit dem Sie verteilte Prüfpunkte verwenden und Ihre Prüfpunkte wiederherstellen können, wenn Sie die Orbax-Bibliothek in der Pathways-Umgebung verwenden. Diese Integration soll ohne Änderungen an Ihrem vorhandenen Orbax-Checkpointing-Code funktionieren, sofern
pathwaysutilsimportiert wird. - Elastische Trainingsprimitive: Bietet grundlegende elastische Trainingsprimitive, mit denen Sie robuste und skalierbare Trainingsworkflows mit Pathways erstellen können. Mit diesen Primitiven können Ihre Trainingsjobs sich dynamisch an Änderungen bei den verfügbaren Ressourcen anpassen, was die Effizienz und Ausfallsicherheit in Cloud-Umgebungen verbessert.
Prüfpunkte
Orbax wird mit Pathways für verteiltes Erstellen und Wiederherstellen von Prüfpunkten mit Cloud Storage gründlich getestet. Wenn Sie import pathwaysutils; pathwaysutils.initialize() in Ihrem train.py aufrufen, wird ein benutzerdefinierter ArrayHandler registriert, der Prüfpunktvorgänge effizient über den IFRT-Proxy verarbeitet. Dadurch können Pathways-Worker auf Beschleunigern Daten direkt speichern und wiederherstellen.
Colocated Python
Colocated Python ist eine Open-Source-JAX-API, mit der Sie benutzerdefinierten Python-Code direkt auf den TPU- oder GPU-Hosts ausführen können. Das ist in JAX mit mehreren Controllern einfacher. So können rechenintensive Aufgaben wie das Laden von Daten und das Erstellen von Checkpoints vermieden werden, da keine Daten zwischen dem Client und den TPU-Maschinen übertragen werden müssen. Wenn Sie Ihren Pathways-Cluster so konfigurieren möchten, dass die Python-JAX-API am selben Ort ausgeführt wird, folgen Sie der Anleitung in der README-Datei für Python am selben Ort. Dort wird beschrieben, wie Sie einen Python-Sidecar am selben Ort wie Ihre Pathways-Worker starten.
Laden der Daten
Während des Trainings werden wiederholt Batches aus einem Dataset geladen, um sie in das Modell einzuspeisen. Ein effizienter, asynchroner Data Loader, der den Batch auf Hosts verteilt, ist wichtig, um unterausgelastete Beschleuniger zu vermeiden. Wenn Sie das Training mit Pathways ausführen, wird der Daten-Loader auf einer CPU-VM ausgeführt (im Gegensatz zu einer TPU-VM, die in Setups mit mehreren Controllern verwendet wird) und sendet Daten an TPU-VMs. Dies führt zu einer höheren Latenz beim Lesen von Daten. Diese wird jedoch teilweise dadurch gemindert, dass auf dem CPU-Host X Batches im Voraus gelesen und die gelesenen Daten asynchron an die TPUs gesendet werden. Diese Lösung ist ausreichend, wenn Sie sie in kleinem bis mittlerem Umfang einsetzen.
Für eine optimale Leistung im großen Maßstab empfehlen wir dringend, Ihre Eingabedatenpipeline mit colocated Python zu platzieren, um Ihre Datenpipeline direkt auf den Beschleunigern auszuführen. Dadurch wird der CPU-Engpass beseitigt und die schnellen Verbindungen der TPU für die Datenübertragung genutzt.
Eine Referenzimplementierung für die Migration einer TFDS-basierten Eingabepipeline finden Sie in der RemoteIterator-Implementierung in multihost_dataloading.py.
Diese Implementierung funktioniert sowohl für JAX mit mehreren Controllern als auch für Pathways auf verteilte Weise mit der Python-JAX-API.
Jax-Versionsverwaltung
Pathways-Releases sind eng mit JAX-Versionen verknüpft, um Kompatibilität und Stabilität zu gewährleisten. Um potenzielle Probleme zu vermeiden, sollten Sie darauf achten, dass Ihre Pathways-Artefakte und Ihre JAX-Version aufeinander abgestimmt sind. In jedem Pathways-Release werden die kompatiblen JAX-Versionen durch ein Tag der Form jax-<version> angegeben.
Kompilierungs-Cache
Der persistente Kompilierungscache von Pathways ist eine Funktion, mit der Pathways-Server kompilierte XLA-Ausführungsdateien an einem persistenten Speicherort wie Cloud Storage speichern können, um redundante Kompilierungen zu vermeiden. Diese Funktion ist standardmäßig aktiviert. Der Speicherort des Cache wird als --gcs_scratch_location-Flag an die Ressourcenmanager- und Pathways-Worker-Container übergeben. Um die zugehörigen Speicherkosten so gering wie möglich zu halten, wird dem Cloud Storage-Speicherort eine Lebenszyklusrichtlinie zugewiesen. Pro Cloud Storage-Bucket gilt ein Limit von 50 Richtlinien. Daher empfehlen wir, für alle Arbeitslasten einen gemeinsamen Cloud Storage-Speicherort zu verwenden.
Dieser Cache ähnelt dem JAX-Kompilierungscache, der für Pathways-Arbeitslasten durch pathwaysutils.initialize() deaktiviert wird.
Profilerstellung
Mit dem JAX-Profiler können Sie Traces eines JAX-Programms generieren. Es gibt zwei gängige Methoden, die mit Pathways unterstützt werden:
- Programmatisch
- Profile programmatisch aus Ihrem JAX-Code erfassen
- Manuell
- Profile bei Bedarf erfassen, nachdem der Profiler-Server über Ihren JAX-Code gestartet wurde
In beiden Fällen werden die Profile in einen Cloud Storage-Bucket geschrieben. Im Cloud Storage-Bucket werden mehrere Tracedateien erstellt, möglicherweise in verschiedenen Zeitstempelordnern, z. B.:
- Haupt-Python-Prozess, der den Trace aufgerufen hat (in der Regel Ihre Notebook-VM):
<jax-client-vm-name>.xplane.pb - IFRT-Proxy für Pathways:
client.<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb - Pathways Resource Manager:
server.*<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb - Pathways-Mitarbeiter:
server.*<tpu-node-name>.xplane.pb
Diese Tracedateien können mit TensorBoard analysiert werden. Führen Sie dazu den folgenden Befehl aus. Weitere Informationen zu TensorBoard und allen Profiler-Tools finden Sie unter TensorFlow-Leistung mit dem Profiler optimieren.
# verify trace files are present gsutil ls -l -r gs://BUCKET/PREFIX # View on tensorboard tensorboard --logdir=gs://BUCKET/PREFIX
Ersetzen Sie Folgendes:
BUCKET: Ein Cloud Storage-Bucket zum Speichern der Tracedateien.PREFIX: Ein Pfad in Ihrem Cloud Storage-Bucket zum Speichern der Tracedateien
Programmatische Profilerfassung
Erstellen Sie ein Profil aus Ihrem Code heraus. Die Profile werden in gs://<bucket>/<prefix> in einem Verzeichnis mit Zeitstempel gespeichert.
import jax import pathwaysutils pathwaysutils.initialize() jax.profiler.start_trace("gs://BUCKET/PREFIX") # Run the operations to be profiled key = jax.random.key(0) x = jax.random.normal(key, (5000, 5000)) y = x @ x y.block_until_ready() jax.profiler.stop_trace()
Manuelle Profilerfassung
Wenn Sie Profilinformationen manuell erfassen möchten, müssen Sie den Profiler-Server über Ihren Python-Code starten:
import jax
import pathwaysutils
pathwaysutils.initialize()
jax.profiler.start_server(jax_profiler_port)
# Your JAX code
jax.profiler.stop_server() # this is functinoally a no-op
Während der Profiler-Server ausgeführt wird, können Sie ein Profil erfassen und die Daten an den Zielspeicherort in Cloud Storage exportieren:
export DURATION_IN_SECS=6
export OUTPUT_DIR=gs://<var>BUCKET</var>/<var>PREFIX</var>
curl -d "{\"duration_ms\":\"${DURATION_IN_SECS} * 1000 }}\", \"repository_path\":\"${OUTPUT_DIR}\"}" -H "Content-Type: application/json" -X POST http://localhost:<jax_profiler_port
Zeitinformationen für IFRT-Proxy-Client-Methoden wie Compile und Execute finden Sie im Trace Ihres Programms. Diese Ereignisse, die die Interaktionen mit dem IFRT Proxy-gRPC-Server während der Kompilierung und Ausführung beschreiben, werden im Thread mit dem Namen GrpcClientSessionUserFuturesWorkQueue angezeigt. Wenn Sie sich diesen Thread in Ihrem Trace ansehen, erhalten Sie Einblicke in die Leistung dieser Vorgänge.
XLA-Flags
Wenn Sie Pathways verwenden, müssen Sie die XLA-Flags im Container „pathways-proxy“ festlegen. Sie können dazu XPK oder die PathwaysJob API verwenden.
Wenn Sie XPK verwenden, legen Sie XLA-Flags so fest:
--custom-pathways-proxy-server-args="--xla_flag_1=value1 --xla_flag_2=value2"
Wenn Sie die PathwaysJob API verwenden, legen Sie XLA-Flags so fest:
apiVersion: pathways-job.pathways.domain/v1 kind: PathwaysJob metadata: name: pathways-USER spec: customComponents: - componentType: proxy_server customFlags: - --xla_flag_1=value1 - --xla_flag_2=value2
Ersetzen Sie Folgendes:
USER: Ihr Google Cloud Nutzernamevalue[n]: die XLA-Flags, die Sie festlegen möchten
HLO-Dump
Wenn Sie sich die HLO-Eingaben (High Level Optimizer) ansehen möchten, die an den XLA-Compiler übergeben werden, können Sie Pathways so konfigurieren, dass die HLO-Eingaben an einem angegebenen Cloud Storage-Speicherort gespeichert werden:
apiVersion: pathways-job.pathways.domain/v1 kind: PathwaysJob metadata: name: pathways-USER spec: customComponents: - componentType: proxy_server customEnv: - name: XLA_FLAGS value: "--xla_dump_to=gs://your-gcs-bucket/your-desired-prefix/"
Nächste Schritte
- GKE-Cluster mit Pathways erstellen
- Inferenz auf mehreren Hosts mit Pathways
- Batch-Arbeitslasten mit Pathways
- Interaktiver Modus für Lernpfade
- Belastbares Training mit Pathways
- Pfade zur Fehlerbehebung