Portare i carichi di lavoro JAX su Pathways

A causa della natura distribuita di JAX con Pathways, alcune operazioni potrebbero non scalare bene a causa dei sovraccarichi di comunicazione. Sebbene Pathways riduca al minimo questi overhead con funzionalità come l'invio asincrono, ci sono alcune cose di cui devi essere a conoscenza quando porti i carichi di lavoro JAX su Pathways o quando ridimensioni un carico di lavoro JAX con Pathways a un numero elevato di acceleratori.

Prima di iniziare

Assicurati di avere:

Indice di elaborazione

JAX con Pathways considera tutti i dispositivi del cluster Pathways come locali. Ciò semplifica la gestione dei dispositivi e consente a JAX di utilizzare tutte le risorse disponibili. In pratica, ciò significa che:

  • jax.process_index() è sempre 0 per tutti i dispositivi.
  • jax.devices() e jax.local_devices() restituiscono tutti i dispositivi TPU nell'intero job.

Tipo di hardware e colocation

Per ottenere le prestazioni migliori, posiziona tutti i componenti di Percorsi e il job utente nella stessa Google Cloud zona cloud. Utilizza una CPU di grandi dimensioni come il proxy IFRT e Resource Manager. Ti consigliamo almeno un n2-standard-64 dedicato, dotato di 64 vCPU e 256 GB di memoria.

PathwaysUtils

Pathways-utils è un repository GitHub basato su Python che fornisce utilità e strumenti essenziali per semplificare il deployment e l'esecuzione dei carichi di lavoro JAX nell'architettura Pathways on Cloud. Questo pacchetto gestisce gli adattamenti necessari per l'ambiente cloud, consentendo agli sviluppatori JAX di concentrarsi sui flussi di lavoro principali di machine learning con una configurazione specifica della piattaforma minima. Nello specifico, offre:

  • Un backend JAX "proxy": questo backend personalizzato consente alla tua applicazione JAX di utilizzare l'infrastruttura Pathways impostando la variabile di ambiente JAX_PLATFORMS=proxy.
  • Utilità di profilazione integrate: funzionalità di profilazione che ti consentono di comprendere le prestazioni della tua applicazione. Utilizzando le API di profilazione JAX standard come jax.profiler.start_trace e jax.profiler.start_server, puoi profilare non solo il codice JAX, ma anche i componenti Pathways sottostanti, fornendo una visione olistica dell'esecuzione nell'ambiente cloud.
  • Checkpointing distribuito con Orbax: un gestore di checkpoint Orbax personalizzato che ti consente di utilizzare i checkpoint distribuiti e ripristinarli quando utilizzi la libreria Orbax nell'ambiente Pathways. Questa integrazione ha lo scopo di funzionare senza richiedere modifiche al codice di checkpointing Orbax esistente purché importi pathwaysutils.
  • Primitive di addestramento elastiche: fornisce primitive di addestramento elastiche di base che puoi utilizzare per creare workflow di addestramento scalabili e robusti utilizzando Pathways. Queste primitive consentono ai job di addestramento di adattarsi dinamicamente alle modifiche delle risorse disponibili, migliorando l'efficienza e la resilienza negli ambienti cloud.

Checkpoint

Orbax è stato testato a fondo con Pathways per il checkpointing e il ripristino distribuiti con Cloud Storage. Quando chiami import pathwaysutils; pathwaysutils.initialize() in train.py, viene registrato un ArrayHandler personalizzato che gestisce in modo efficiente le operazioni di checkpoint tramite il proxy IFRT, consentendo ai worker di Pathways sugli acceleratori di salvare e ripristinare direttamente i dati.

Python colocalizzato

Python colocalizzato è un'API JAX open source che consente di eseguire codice Python specificato dall'utente direttamente sugli host TPU o GPU, il che è più semplice in JAX multi-controller. Ciò consente di evitare il trasferimento di dati tra il client e le macchine TPU per attività che richiedono un maggiore utilizzo di risorse di calcolo, come il caricamento dei dati e il checkpointing. Per configurare il cluster Pathways per eseguire l'API Python JAX collocate, segui le istruzioni nel file README di Python collocato. Queste istruzioni spiegano come avviare un sidecar Python collocato insieme ai worker Pathways.

Caricamento dei dati

Durante l'addestramento, carichiamo ripetutamente batch da un set di dati per inserirli nel modello. Disporre di un caricatore di dati asincrono efficiente che distribuisca il batch tra gli host è importante per evitare di privare gli acceleratori di lavoro. Quando esegui l'addestramento con Pathways, il caricatore dei dati viene eseguito su una VM CPU (a differenza di una VM TPU che viene utilizzata nelle configurazioni multi-controller) e invia i dati alle VM TPU. Ciò comporta una latenza maggiore nella lettura dei dati, ma questo problema viene parzialmente mitigato leggendo in anticipo X batch sull'host CPU e inviando i dati letti in modo asincrono alle TPU. Questa soluzione è sufficiente quando viene eseguita su scala ridotta o media.

Per prestazioni ottimali su larga scala, ti consigliamo vivamente di collocare la pipeline di dati di input utilizzando Python colocalizzato per eseguire la pipeline di dati direttamente sugli acceleratori. In questo modo si elimina il collo di bottiglia della CPU e si sfruttano le interconnessioni veloci della TPU per il trasferimento dei dati.

Puoi trovare un'implementazione di riferimento della migrazione di una pipeline di input basata su TFDS nell'implementazione di RemoteIterator in multihost_dataloading.py. Questa implementazione funziona sia su JAX multi-controller sia su Pathways in modo distribuito utilizzando l'API Python JAX collocate.

Controllo delle versioni di Jax

Le release di Pathways sono strettamente accoppiate alle versioni di JAX per garantire compatibilità e stabilità. Per evitare potenziali problemi, verifica che gli artefatti di Pathways e la tua versione di JAX siano allineati. Ogni release di Pathways specifica chiaramente le versioni di JAX compatibili tramite un tag nel formato jax-<version>.

Cache di compilazione

La cache di compilazione persistente di Pathways è una funzionalità che consente ai server Pathways di archiviare gli eseguibili XLA compilati in una posizione persistente, ad esempio Cloud Storage, per evitare compilazioni ridondanti. Questa funzionalità è attivata per impostazione predefinita. La posizione della cache viene passata come flag --gcs_scratch_location ai container di Resource Manager e Pathways Worker. Per ridurre al minimo i costi di archiviazione associati, la cache collega un criterio del ciclo di vita alla posizione Cloud Storage. Esiste un limite di 50 policy per bucket Cloud Storage. Pertanto, ti consigliamo di utilizzare una posizione Cloud Storage comune in tutti i workload.

Questa cache è simile alla cache di compilazione JAX che è disabilitata da pathwaysutils.initialize() per i carichi di lavoro Pathways.

Profilazione

Puoi utilizzare JAX Profiler per generare tracce di un programma JAX. Esistono due metodi comuni supportati da Pathways:

  • Pubblicità programmatica
    • Acquisire i profili in modo programmatico dal codice JAX
  • Manuale
    • Acquisizione di profili on demand dopo l'avvio del server Profiler dal codice JAX

In entrambi i casi, i profili vengono scritti in un bucket Cloud Storage. Nel bucket Cloud Storage verranno creati più file di traccia, potenzialmente in cartelle con timestamp diversi, ad esempio:

  • Processo Python principale che ha richiamato la traccia (in genere la VM del notebook): <jax-client-vm-name>.xplane.pb
  • Proxy IFRT di Pathways: client.<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb
  • Resource Manager di Pathways: server.*<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb
  • Operatore/i di Pathways: server.*<tpu-node-name>.xplane.pb

Questi file di traccia possono essere analizzati con TensorBoard eseguendo il seguente comando. Per saperne di più su TensorBoard e su tutti i suoi strumenti di profilazione, consulta Ottimizzare il rendimento di TensorFlow utilizzando Profiler.

# verify trace files are present
gsutil ls -l -r gs://BUCKET/PREFIX

# View on tensorboard
tensorboard --logdir=gs://BUCKET/PREFIX

Sostituisci quanto segue:

  • BUCKET : un bucket Cloud Storage in cui archiviare i file di traccia
  • PREFIX: un percorso all'interno del bucket Cloud Storage in cui archiviare i file di traccia

Acquisizione programmatica dei profili

Acquisisci un profilo dall'interno del codice. I profili vengono salvati all'interno di gs://<bucket>/<prefix> in una directory con timestamp

import jax
import pathwaysutils

pathwaysutils.initialize()

jax.profiler.start_trace("gs://BUCKET/PREFIX")

# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()

jax.profiler.stop_trace()

Acquisizione manuale del profilo

Per acquisire manualmente le informazioni del profilo, devi avviare il server del profiler dal codice Python:

import jax
import pathwaysutils

pathwaysutils.initialize()
jax.profiler.start_server(jax_profiler_port)

# Your JAX code
jax.profiler.stop_server() # this is functinoally a no-op

Mentre il server Profiler è in esecuzione, puoi acquisire un profilo ed esportare i dati nella posizione di destinazione di Cloud Storage:

export DURATION_IN_SECS=6
export OUTPUT_DIR=gs://<var>BUCKET</var>/<var>PREFIX</var>

curl -d "{\"duration_ms\":\"${DURATION_IN_SECS} * 1000 }}\", \"repository_path\":\"${OUTPUT_DIR}\"}" -H "Content-Type: application/json" -X POST http://localhost:<jax_profiler_port

Puoi trovare informazioni sui tempi per i metodi del client proxy IFRT come Compile e Execute all'interno della traccia del tuo programma. Questi eventi, che descrivono in dettaglio le interazioni con il server gRPC proxy IFRT durante la compilazione e l'esecuzione, vengono visualizzati nel thread denominato GrpcClientSessionUserFuturesWorkQueue. Esaminando questo thread nella traccia, puoi ottenere informazioni sul rendimento di queste operazioni.

Flag XLA

Quando utilizzi Pathways, devi impostare i flag XLA nel container pathways-proxy. Puoi farlo utilizzando XPK o l'API PathwaysJob.

Quando utilizzi XPK, imposta i flag XLA come segue:

--custom-pathways-proxy-server-args="--xla_flag_1=value1 --xla_flag_2=value2"

Quando utilizzi l'API PathwaysJob, imposta i flag XLA come segue:

apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USER
spec:
  customComponents:
  - componentType: proxy_server
    customFlags:
    - --xla_flag_1=value1
    - --xla_flag_2=value2

Sostituisci quanto segue:

  • USER : il tuo Google Cloud nome utente
  • value[n]: i flag XLA che vuoi impostare

Dump HLO

Per analizzare nel dettaglio gli input di High Level Optimizer (HLO) forniti al compilatore XLA, puoi configurare Pathways per eseguire il dump di HLO in una posizione Cloud Storage specificata nel seguente modo:

apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USER
spec:
  customComponents:
  - componentType: proxy_server
    customEnv:
    - name: XLA_FLAGS
      value: "--xla_dump_to=gs://your-gcs-bucket/your-desired-prefix/"

Passaggi successivi