Portare i carichi di lavoro JAX su Pathways

A causa della natura distribuita di JAX con Pathways, alcune operazioni potrebbero non scalare bene a causa degli overhead di comunicazione. Sebbene Pathways riduca al minimo questi overhead con funzionalità come l'invio asincrono, devi tenere presente alcune cose quando porti i carichi di lavoro JAX su Pathways o quando aumenti la scalabilità di un carico di lavoro JAX con Pathways a un numero elevato di acceleratori.

Prima di iniziare

Assicurati di avere:

Indice di elaborazione

JAX con Pathways considera tutti i dispositivi del cluster Pathways come locali. Ciò semplifica la gestione dei dispositivi e consente a JAX di utilizzare tutte le risorse disponibili. In pratica, questo significa:

  • jax.process_index() è sempre 0 per tutti i dispositivi.
  • jax.devices() e jax.local_devices() restituiscono tutti i dispositivi TPU dell'intero job.

Tipo di hardware e colocation

Per ottenere prestazioni ottimali, inserisci tutti i componenti Pathways e il job utente nella stessa Google Cloud zona cloud. Utilizza una CPU di grandi dimensioni, come il proxy IFRT e Resource Manager. Ti consigliamo almeno un n2-standard-64 dedicato con 64 vCPU e 256 GB di memoria.

PathwaysUtils

Pathways-utils è un repository GitHub basato su Python che fornisce utilità e strumenti essenziali per semplificare il deployment e l'esecuzione dei carichi di lavoro JAX sull'architettura Pathways on Cloud. Questo pacchetto gestisce gli adattamenti necessari per l'ambiente cloud, consentendo agli sviluppatori JAX di concentrarsi sui flussi di lavoro di machine learning di base con una configurazione minima specifica della piattaforma. In particolare, offre:

  • Un backend JAX "proxy": questo backend personalizzato consente alla tua applicazione JAX di utilizzare l'infrastruttura Pathways impostando la variabile di ambiente JAX_PLATFORMS=proxy.
  • Utilità di profilazione integrate: funzionalità di profilazione che consentono di comprendere le prestazioni dell'applicazione. Utilizzando le API di profilazione JAX standard come jax.profiler.start_trace e jax.profiler.start_server, puoi profilare non solo il codice JAX, ma anche i componenti Pathways sottostanti, fornendo una visione olistica dell'esecuzione nell'ambiente cloud.
  • Checkpointing distribuito con Orbax: un gestore di checkpoint Orbax personalizzato che consente di utilizzare i checkpoint distribuiti e ripristinarli quando si utilizza la libreria Orbax nell'ambiente Pathways. Questa integrazione mira a funzionare senza richiedere modifiche al codice di checkpointing Orbax esistente, a condizione che importi pathwaysutils.
  • Primitive di addestramento elastiche: fornisce primitive di addestramento elastiche di base che puoi utilizzare per creare flussi di lavoro di addestramento robusti e scalabili utilizzando Pathways. Queste primitive consentono ai job di addestramento di adattarsi dinamicamente alle modifiche delle risorse disponibili, migliorando l'efficienza e la resilienza negli ambienti cloud.

Checkpointing

Orbax è stato testato a fondo con Pathways per il checkpointing distribuito e il ripristino con Cloud Storage. Quando chiami import pathwaysutils; pathwaysutils.initialize() in train.py, viene registrato un ArrayHandler personalizzato che gestisce in modo efficiente le operazioni di checkpointing tramite il proxy IFRT, consentendo ai worker Pathways sugli acceleratori di salvare e ripristinare direttamente i dati.

Python in colocation

Python in colocation è un'API JAX open source che consente di eseguire il codice Python specificato dall'utente direttamente sugli host TPU o GPU, il che è più semplice in JAX multi-controller JAX. In questo modo, le attività che richiedono un utilizzo intensivo del calcolo, come il caricamento dei dati e il checkpointing, possono evitare il trasferimento di dati tra il client e le macchine TPU. Per configurare il cluster Pathways in modo che esegua l'API JAX Python in colocation, segui le istruzioni riportate nel file README di Python in colocation. Queste istruzioni spiegano come avviare un sidecar Python in colocation insieme ai worker Pathways.

Caricamento dei dati

Durante l'addestramento, carichiamo ripetutamente i batch da un set di dati per inserirli nel modello. È importante disporre di un caricatore di dati asincrono ed efficiente che esegua lo sharding del batch tra gli host per evitare che gli acceleratori non abbiano lavoro. Quando esegui l'addestramento con Pathways, il caricatore di dati viene eseguito su una VM CPU (a differenza di una VM TPU utilizzata nelle configurazioni multi-controller) e invia i dati alle VM TPU. Ciò comporta una latenza maggiore nella lettura dei dati, ma viene mitigata parzialmente leggendo in anticipo X batch sull'host CPU e inviando i dati letti in modo asincrono alle TPU. Questa soluzione è sufficiente quando si esegue su scala da piccola a media.

Per prestazioni ottimali su larga scala, ti consigliamo vivamente di collocare la pipeline di dati di input utilizzando Python in colocation per eseguire la pipeline di dati direttamente su gli acceleratori. In questo modo si elimina il collo di bottiglia della CPU e si sfruttano le interconnessioni veloci della TPU per il trasferimento dei dati.

Puoi trovare un'implementazione di riferimento della migrazione di una pipeline di input basata su TFDS nell'RemoteIterator implementazione in multihost_dataloading.py. Questa implementazione funziona sia su JAX multi-controller sia su Pathways in modo distribuito utilizzando l'API JAX Python in colocation.

Controllo delle versioni di JAX

Le release di Pathways sono strettamente accoppiate alle versioni di JAX per garantire compatibilità e stabilità. Per evitare potenziali problemi, verifica che gli artefatti Pathways e la versione di JAX siano allineati. Ogni release di Pathways specifica chiaramente le versioni di JAX compatibili tramite un tag nel formato jax-<version>.

Cache di compilazione

La cache di compilazione persistente di Pathways è una funzionalità che consente ai server Pathways di archiviare gli eseguibili XLA compilati in una località persistente, ad esempio Cloud Storage, per evitare compilazioni ridondanti. Questa funzione è attivata per impostazione predefinita. La località della cache viene passata come flag --gcs_scratch_location ai container di Resource Manager e worker Pathways. Per ridurre al minimo i costi di archiviazione associati, la cache associa un criterio del ciclo di vita alla località Cloud Storage. Esiste un limite di 50 criteri per bucket Cloud Storage. Pertanto, ti consigliamo di utilizzare una località Cloud Storage comune per tutti i carichi di lavoro.

Questa cache è simile alla cache di compilazione JAX che viene disabilitata da pathwaysutils.initialize() per i carichi di lavoro Pathways.

Per la cache di compilazione sono necessarie le seguenti autorizzazioni Cloud Storage:

  • storage.buckets.get: per recuperare i metadati del bucket.
  • storage.buckets.update: essenziale per Pathways per configurare i criteri del ciclo di vita degli oggetti per applicare il TTL per l'eliminazione della cache.
  • storage.objects.list: per elencare gli oggetti cache esistenti all'interno del bucket.
  • storage.objects.create: per scrivere nuovi eseguibili compilati nella cache.
  • storage.objects.get: per leggere gli eseguibili memorizzati nella cache dal bucket.

Profilazione

Puoi utilizzare il profiler JAX per generare tracce di un programma JAX. Esistono due modi comuni supportati da Pathways:

  • Programmatico
    • Acquisisci i profili in modo programmatico dal codice JAX
  • Manuale
    • Acquisisci i profili on demand dopo aver avviato il server del profiler dal codice JAX

In entrambi i casi, i profili vengono scritti in un bucket Cloud Storage. Nel bucket Cloud Storage verranno creati più file di traccia, potenzialmente in cartelle con timestamp diversi, ad esempio:

  • Processo Python principale che ha richiamato la traccia (in genere la VM del notebook): <jax-client-vm-name>.xplane.pb
  • Proxy IFRT Pathways: client.<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb
  • Resource Manager Pathways: server.*<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb
  • Worker Pathways: server.*<tpu-node-name>.xplane.pb

Questi file di traccia possono essere analizzati con TensorBoard eseguendo il seguente comando. Per ulteriori informazioni su TensorBoard e su tutti i relativi strumenti di profilazione, consulta Ottimizzare le prestazioni di TensorFlow utilizzando il profiler.

# verify trace files are present
gsutil ls -l -r gs://BUCKET/PREFIX

# View on tensorboard
tensorboard --logdir=gs://BUCKET/PREFIX

Sostituisci quanto segue:

  • BUCKET : un bucket Cloud Storage in cui archiviare i file di traccia
  • PREFIX: un percorso all'interno del bucket Cloud Storage in cui archiviare i file di traccia

Acquisizione programmatica dei profili

Acquisisci un profilo dall'interno del codice. I profili vengono salvati in gs://<bucket>/<prefix> in una directory con timestamp

import jax
import pathwaysutils

pathwaysutils.initialize()

jax.profiler.start_trace("gs://BUCKET/PREFIX")

# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()

jax.profiler.stop_trace()

Acquisizione manuale dei profili

Per acquisire manualmente le informazioni del profilo, devi avviare il server del profiler dal codice Python:

import jax
import pathwaysutils

pathwaysutils.initialize()
jax.profiler.start_server(jax_profiler_port)

# Your JAX code
jax.profiler.stop_server() # this is functinoally a no-op

Mentre il server del profiler è in esecuzione, puoi acquisire un profilo ed esportare i dati nella località Cloud Storage di destinazione:

export DURATION_IN_SECS=6
export OUTPUT_DIR=gs://<var>BUCKET</var>/<var>PREFIX</var>

curl -d "{\"duration_ms\":\"${DURATION_IN_SECS} * 1000 }}\", \"repository_path\":\"${OUTPUT_DIR}\"}" -H "Content-Type: application/json" -X POST http://localhost:<jax_profiler_port

Puoi trovare informazioni sui tempi per i metodi client del proxy IFRT, come Compile ed Execute, all'interno della traccia del programma. Questi eventi, che descrivono in dettaglio le interazioni con il server gRPC del proxy IFRT durante la compilazione e l'esecuzione, vengono visualizzati nel thread denominato GrpcClientSessionUserFuturesWorkQueue. Esaminando questo thread nella traccia, puoi ottenere informazioni sulle prestazioni di queste operazioni.

Flag XLA

Quando utilizzi Pathways, devi impostare i flag XLA nel container pathways-proxy. Puoi farlo utilizzando XPK o l'API PathwaysJob.

Quando utilizzi XPK, imposta i flag XLA come segue:

--custom-pathways-proxy-server-args="--xla_flag_1=value1 --xla_flag_2=value2"

Quando utilizzi l'API PathwaysJob, imposta i flag XLA come segue:

apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USER
spec:
  customComponents:
  - componentType: proxy_server
    customFlags:
    - --xla_flag_1=value1
    - --xla_flag_2=value2

Sostituisci quanto segue:

  • USER : il tuo Google Cloud nome utente
  • value[n]: i flag XLA che vuoi impostare

Dump HLO

Per approfondire gli input di High Level Optimizer (HLO) forniti al compilatore XLA, puoi configurare Pathways in modo che esegua il dump di HLO in una località Cloud Storage specificata come segue:

apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USER
spec:
  customComponents:
  - componentType: proxy_server
    customEnv:
    - name: XLA_FLAGS
      value: "--xla_dump_to=gs://your-bucket/your-desired-prefix/"

Passaggi successivi