Esecuzione del codice JAX nelle sezioni TPU
Prima di eseguire i comandi in questo documento, assicurati di aver seguito le istruzioni riportate in Configurare un account e un progetto Cloud TPU.
Dopo aver eseguito il codice JAX su una singola scheda TPU, puoi scalare il codice eseguendolo su una sezione TPU. Le sezioni TPU sono più schede TPU collegate tra loro tramite connessioni di rete dedicate ad alta velocità. Questo documento è un'introduzione all'esecuzione del codice JAX nelle sezioni TPU. Per informazioni più approfondite, consulta Utilizzo di JAX in ambienti multi-host e multi-processo.
Ruoli obbligatori
Per ottenere le autorizzazioni necessarie per creare una TPU e connetterti a quest'ultima tramite SSH, chiedi all'amministratore di concederti i seguenti ruoli IAM sul progetto:
-
TPU Admin (
roles/tpu.admin) -
Utente Service Account (
roles/iam.serviceAccountUser) -
Compute Viewer (
roles/compute.viewer)
Per saperne di più sulla concessione dei ruoli, consulta Gestisci l'accesso a progetti, cartelle e organizzazioni.
Potresti anche riuscire a ottenere le autorizzazioni richieste tramite i ruoli personalizzati o altri ruoli predefiniti.
Crea uno slice Cloud TPU
Crea alcune variabili di ambiente:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=europe-west4-b export ACCELERATOR_TYPE=v5litepod-32 export RUNTIME_VERSION=v2-alpha-tpuv5-lite
Descrizioni delle variabili di ambiente
Variabile Descrizione PROJECT_IDL'ID progetto Google Cloud . Utilizza un progetto esistente o creane uno nuovo. TPU_NAMEIl nome della TPU. ZONELa zona in cui creare la VM TPU. Per saperne di più sulle zone supportate, consulta Regioni e zone TPU. ACCELERATOR_TYPEIl tipo di acceleratore specifica la versione e le dimensioni della Cloud TPU che vuoi creare. Per saperne di più sui tipi di acceleratore supportati per ogni versione di TPU, consulta la sezione Versioni di TPU. RUNTIME_VERSIONLa versione software di Cloud TPU. Crea una sezione TPU utilizzando il comando
gcloud. Ad esempio, per creare una slice v5litepod-32, utilizza il seguente comando:$ gcloud compute tpus tpu-vm create ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --accelerator-type=${ACCELERATOR_TYPE} \ --version=${RUNTIME_VERSION}
Installa JAX sullo slice
Dopo aver creato la sezione TPU, devi installare JAX su tutti gli host della sezione TPU. Puoi farlo utilizzando il comando gcloud compute tpus tpu-vm ssh con i parametri --worker=all e --commamnd.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --worker=all \ --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Esecuzione del codice JAX sulla sezione
Per eseguire il codice JAX su una sezione TPU, devi eseguire il codice su ogni host nella
sezione TPU. La chiamata jax.device_count() smette di rispondere finché non viene
chiamata su ogni host nella sezione. L'esempio seguente mostra come eseguire un calcolo JAX su una sezione TPU.
Prepara il codice
È necessaria la versione gcloud >= 344.0.0 (per il
comando scp).
Utilizza gcloud --version per controllare la versione di gcloud ed esegui gcloud components upgrade, se necessario.
Crea un file denominato example.py con il seguente codice:
import jax
# Initialize the slice
jax.distributed.initialize()
# The total number of TPU cores in the slice
device_count = jax.device_count()
# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()
# The psum is performed over all mapped devices across the slice
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)
# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
print('global device count:', jax.device_count())
print('local device count:', jax.local_device_count())
print('pmap result:', r)
Copia example.py in tutte le VM worker TPU nello slice
$ gcloud compute tpus tpu-vm scp ./example.py ${TPU_NAME}: \ --worker=all \ --zone=${ZONE} \ --project=${PROJECT_ID}
Se non hai mai utilizzato il comando scp, potresti visualizzare un errore simile al seguente:
ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH
agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try
again.
Per risolvere l'errore, esegui il comando ssh-add come visualizzato nel
messaggio di errore ed esegui di nuovo il comando.
Esegui il codice sulla sezione
Avvia il programma example.py su ogni VM:
$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --worker=all \ --command="python3 ./example.py"
Output (prodotto con una sezione v5litepod-32):
global device count: 32
local device count: 4
pmap result: [32. 32. 32. 32.]
Esegui la pulizia
Al termine dell'utilizzo della VM TPU, segui questi passaggi per pulire le risorse.
Elimina le risorse Cloud TPU e Compute Engine.
$ gcloud compute tpus tpu-vm delete ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID}
Verifica che le risorse siano state eliminate eseguendo
gcloud compute tpus execution-groups list. L'eliminazione potrebbe richiedere diversi minuti. L'output del seguente comando non deve includere nessuna delle risorse create in questo tutorial:$ gcloud compute tpus tpu-vm list --zone=${ZONE} \ --project=${PROJECT_ID}