A causa della natura distribuita di JAX con Pathways, alcune operazioni potrebbero non scalare bene a causa dei sovraccarichi di comunicazione. Sebbene Pathways riduca al minimo questi overhead con funzionalità come l'invio asincrono, ci sono alcune cose di cui devi essere a conoscenza quando porti i carichi di lavoro JAX su Pathways o quando ridimensioni un carico di lavoro JAX con Pathways a un numero elevato di acceleratori.
Prima di iniziare
Assicurati di avere:
- Strumenti Kubernetes installati
- Installazione di gcloud CLI
- Abilitato l'API TPU
- Abilitato l'API Google Kubernetes Engine
Indice di elaborazione
JAX con Pathways considera tutti i dispositivi del cluster Pathways come locali. Ciò semplifica la gestione dei dispositivi e consente a JAX di utilizzare tutte le risorse disponibili. In pratica, ciò significa che:
jax.process_index()è sempre 0 per tutti i dispositivi.jax.devices()ejax.local_devices()restituiscono tutti i dispositivi TPU nell'intero job.
Tipo di hardware e colocation
Per ottenere le prestazioni migliori, posiziona tutti i componenti di Percorsi e il job utente nella
stessa Google Cloud zona cloud. Utilizza una CPU di grandi dimensioni come il proxy IFRT e Resource Manager. Ti consigliamo almeno un n2-standard-64 dedicato, dotato di 64 vCPU e 256 GB di memoria.
PathwaysUtils
Pathways-utils è un repository GitHub basato su Python che fornisce utilità e strumenti essenziali per semplificare il deployment e l'esecuzione dei carichi di lavoro JAX nell'architettura Pathways on Cloud. Questo pacchetto gestisce gli adattamenti necessari per l'ambiente cloud, consentendo agli sviluppatori JAX di concentrarsi sui flussi di lavoro principali di machine learning con una configurazione specifica della piattaforma minima. Nello specifico, offre:
- Un backend JAX "proxy": questo backend personalizzato consente alla tua applicazione JAX di
utilizzare l'infrastruttura Pathways impostando la variabile di ambiente
JAX_PLATFORMS=proxy. - Utilità di profilazione integrate: funzionalità di profilazione che ti consentono di comprendere
le prestazioni della tua applicazione. Utilizzando le API di profilazione JAX standard come
jax.profiler.start_traceejax.profiler.start_server, puoi profilare non solo il codice JAX, ma anche i componenti Pathways sottostanti, fornendo una visione olistica dell'esecuzione nell'ambiente cloud. - Checkpointing distribuito con Orbax: un gestore di checkpoint Orbax personalizzato che
ti consente di utilizzare i checkpoint distribuiti e ripristinarli quando utilizzi
la libreria Orbax nell'ambiente Pathways. Questa integrazione ha lo scopo di
funzionare senza richiedere modifiche al codice di checkpointing Orbax esistente
purché importi
pathwaysutils. - Primitive di addestramento elastiche: fornisce primitive di addestramento elastiche di base che puoi utilizzare per creare workflow di addestramento scalabili e robusti utilizzando Pathways. Queste primitive consentono ai job di addestramento di adattarsi dinamicamente alle modifiche delle risorse disponibili, migliorando l'efficienza e la resilienza negli ambienti cloud.
Checkpoint
Orbax è stato testato a fondo con Pathways per il checkpointing e il ripristino distribuiti con Cloud Storage. Quando chiami
import pathwaysutils; pathwaysutils.initialize() in train.py, viene registrato un
ArrayHandler personalizzato che gestisce in modo efficiente le operazioni di checkpoint
tramite il proxy IFRT, consentendo ai worker di Pathways sugli acceleratori di salvare e ripristinare direttamente i dati.
Python colocalizzato
Python colocalizzato è un'API JAX open source che consente di eseguire codice Python specificato dall'utente direttamente sugli host TPU o GPU, il che è più semplice in JAX multi-controller. Ciò consente di evitare il trasferimento di dati tra il client e le macchine TPU per attività che richiedono un maggiore utilizzo di risorse di calcolo, come il caricamento dei dati e il checkpointing. Per configurare il cluster Pathways per eseguire l'API Python JAX collocate, segui le istruzioni nel file README di Python collocato. Queste istruzioni spiegano come avviare un sidecar Python collocato insieme ai worker Pathways.
Caricamento dei dati
Durante l'addestramento, carichiamo ripetutamente batch da un set di dati per inserirli nel modello. Disporre di un caricatore di dati asincrono efficiente che distribuisca il batch tra gli host è importante per evitare di privare gli acceleratori di lavoro. Quando esegui l'addestramento con Pathways, il caricatore dei dati viene eseguito su una VM CPU (a differenza di una VM TPU che viene utilizzata nelle configurazioni multi-controller) e invia i dati alle VM TPU. Ciò comporta una latenza maggiore nella lettura dei dati, ma questo problema viene parzialmente mitigato leggendo in anticipo X batch sull'host CPU e inviando i dati letti in modo asincrono alle TPU. Questa soluzione è sufficiente quando viene eseguita su scala ridotta o media.
Per prestazioni ottimali su larga scala, ti consigliamo vivamente di collocare la pipeline di dati di input utilizzando Python colocalizzato per eseguire la pipeline di dati direttamente sugli acceleratori. In questo modo si elimina il collo di bottiglia della CPU e si sfruttano le interconnessioni veloci della TPU per il trasferimento dei dati.
Puoi trovare un'implementazione di riferimento della migrazione di una pipeline di input basata su TFDS nell'implementazione di RemoteIterator in
multihost_dataloading.py.
Questa implementazione funziona sia su JAX multi-controller sia su Pathways in modo distribuito utilizzando l'API Python JAX collocate.
Controllo delle versioni di Jax
Le release di Pathways sono strettamente accoppiate alle versioni di JAX per garantire compatibilità
e stabilità. Per evitare potenziali problemi, verifica che gli artefatti di Pathways
e la tua versione di JAX siano allineati. Ogni release di Pathways specifica chiaramente le versioni di JAX compatibili tramite un tag nel formato jax-<version>.
Cache di compilazione
La cache di compilazione persistente di Pathways è una funzionalità che consente ai server Pathways
di archiviare gli eseguibili XLA compilati in una posizione persistente, ad esempio
Cloud Storage, per evitare compilazioni ridondanti. Questa funzionalità è attivata per impostazione predefinita. La posizione della cache viene passata come flag --gcs_scratch_location
ai container di Resource Manager e Pathways Worker. Per ridurre al minimo i costi di archiviazione associati, la cache collega un criterio del ciclo di vita alla posizione Cloud Storage. Esiste un limite di 50 policy per
bucket Cloud Storage. Pertanto, ti consigliamo di utilizzare una posizione Cloud Storage comune in tutti i workload.
Questa cache è simile alla cache di compilazione JAX
che è disabilitata da pathwaysutils.initialize() per i carichi di lavoro Pathways.
Profilazione
Puoi utilizzare JAX Profiler per generare tracce di un programma JAX. Esistono due metodi comuni supportati da Pathways:
- Pubblicità programmatica
- Acquisire i profili in modo programmatico dal codice JAX
- Manuale
- Acquisizione di profili on demand dopo l'avvio del server Profiler dal codice JAX
In entrambi i casi, i profili vengono scritti in un bucket Cloud Storage. Nel bucket Cloud Storage verranno creati più file di traccia, potenzialmente in cartelle con timestamp diversi, ad esempio:
- Processo Python principale che ha richiamato la traccia (in genere la VM del notebook):
<jax-client-vm-name>.xplane.pb - Proxy IFRT di Pathways:
client.<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb - Resource Manager di Pathways:
server.*<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb - Operatore/i di Pathways:
server.*<tpu-node-name>.xplane.pb
Questi file di traccia possono essere analizzati con TensorBoard eseguendo il seguente comando. Per saperne di più su TensorBoard e su tutti i suoi strumenti di profilazione, consulta Ottimizzare il rendimento di TensorFlow utilizzando Profiler.
# verify trace files are present gsutil ls -l -r gs://BUCKET/PREFIX # View on tensorboard tensorboard --logdir=gs://BUCKET/PREFIX
Sostituisci quanto segue:
BUCKET: un bucket Cloud Storage in cui archiviare i file di tracciaPREFIX: un percorso all'interno del bucket Cloud Storage in cui archiviare i file di traccia
Acquisizione programmatica dei profili
Acquisisci un profilo dall'interno del codice. I profili vengono salvati all'interno di
gs://<bucket>/<prefix> in una directory con timestamp
import jax import pathwaysutils pathwaysutils.initialize() jax.profiler.start_trace("gs://BUCKET/PREFIX") # Run the operations to be profiled key = jax.random.key(0) x = jax.random.normal(key, (5000, 5000)) y = x @ x y.block_until_ready() jax.profiler.stop_trace()
Acquisizione manuale del profilo
Per acquisire manualmente le informazioni del profilo, devi avviare il server del profiler dal codice Python:
import jax
import pathwaysutils
pathwaysutils.initialize()
jax.profiler.start_server(jax_profiler_port)
# Your JAX code
jax.profiler.stop_server() # this is functinoally a no-op
Mentre il server Profiler è in esecuzione, puoi acquisire un profilo ed esportare i dati nella posizione di destinazione di Cloud Storage:
export DURATION_IN_SECS=6
export OUTPUT_DIR=gs://<var>BUCKET</var>/<var>PREFIX</var>
curl -d "{\"duration_ms\":\"${DURATION_IN_SECS} * 1000 }}\", \"repository_path\":\"${OUTPUT_DIR}\"}" -H "Content-Type: application/json" -X POST http://localhost:<jax_profiler_port
Puoi trovare informazioni sui tempi per i metodi del client proxy IFRT come Compile e
Execute all'interno della traccia del tuo programma. Questi eventi, che descrivono in dettaglio le
interazioni con il server gRPC proxy IFRT durante la compilazione e l'esecuzione,
vengono visualizzati nel thread denominato GrpcClientSessionUserFuturesWorkQueue. Esaminando
questo thread nella traccia, puoi ottenere informazioni sul rendimento di queste
operazioni.
Flag XLA
Quando utilizzi Pathways, devi impostare i flag XLA nel container pathways-proxy. Puoi farlo utilizzando XPK o l'API PathwaysJob.
Quando utilizzi XPK, imposta i flag XLA come segue:
--custom-pathways-proxy-server-args="--xla_flag_1=value1 --xla_flag_2=value2"
Quando utilizzi l'API PathwaysJob, imposta i flag XLA come segue:
apiVersion: pathways-job.pathways.domain/v1 kind: PathwaysJob metadata: name: pathways-USER spec: customComponents: - componentType: proxy_server customFlags: - --xla_flag_1=value1 - --xla_flag_2=value2
Sostituisci quanto segue:
USER: il tuo Google Cloud nome utentevalue[n]: i flag XLA che vuoi impostare
Dump HLO
Per analizzare nel dettaglio gli input di High Level Optimizer (HLO) forniti al compilatore XLA, puoi configurare Pathways per eseguire il dump di HLO in una posizione Cloud Storage specificata nel seguente modo:
apiVersion: pathways-job.pathways.domain/v1 kind: PathwaysJob metadata: name: pathways-USER spec: customComponents: - componentType: proxy_server customEnv: - name: XLA_FLAGS value: "--xla_dump_to=gs://your-gcs-bucket/your-desired-prefix/"
Passaggi successivi
- Crea un cluster GKE con Pathways
- Inferenza multihost con Pathways
- Carichi di lavoro batch con percorsi
- Modalità interattiva di Pathways
- Formazione resiliente con Pathways
- Percorsi di risoluzione dei problemi