Esecuzione del codice JAX nelle sezioni TPU

Prima di eseguire i comandi in questo documento, assicurati di aver seguito le istruzioni riportate in Configurare un account e un progetto Cloud TPU.

Dopo aver eseguito il codice JAX su una singola scheda TPU, puoi scalare il codice eseguendolo su una sezione TPU. Le sezioni TPU sono più schede TPU collegate tra loro tramite connessioni di rete dedicate ad alta velocità. Questo documento è un'introduzione all'esecuzione del codice JAX nelle sezioni TPU. Per informazioni più approfondite, consulta Utilizzo di JAX in ambienti multi-host e multi-processo.

Ruoli obbligatori

Per ottenere le autorizzazioni necessarie per creare una TPU e connetterti a quest'ultima tramite SSH, chiedi all'amministratore di concederti i seguenti ruoli IAM sul progetto:

Per saperne di più sulla concessione dei ruoli, consulta Gestisci l'accesso a progetti, cartelle e organizzazioni.

Potresti anche riuscire a ottenere le autorizzazioni richieste tramite i ruoli personalizzati o altri ruoli predefiniti.

Crea uno slice Cloud TPU

  1. Crea alcune variabili di ambiente:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=europe-west4-b
    export ACCELERATOR_TYPE=v5litepod-32
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite

    Descrizioni delle variabili di ambiente

    Variabile Descrizione
    PROJECT_ID L'ID progetto Google Cloud . Utilizza un progetto esistente o creane uno nuovo.
    TPU_NAME Il nome della TPU.
    ZONE La zona in cui creare la VM TPU. Per saperne di più sulle zone supportate, consulta Regioni e zone TPU.
    ACCELERATOR_TYPE Il tipo di acceleratore specifica la versione e le dimensioni della Cloud TPU che vuoi creare. Per saperne di più sui tipi di acceleratore supportati per ogni versione di TPU, consulta la sezione Versioni di TPU.
    RUNTIME_VERSION La versione software di Cloud TPU.

  2. Crea una sezione TPU utilizzando il comando gcloud. Ad esempio, per creare una slice v5litepod-32, utilizza il seguente comando:

    $ gcloud compute tpus tpu-vm create ${TPU_NAME}  \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        --accelerator-type=${ACCELERATOR_TYPE}  \
        --version=${RUNTIME_VERSION} 

Installa JAX sullo slice

Dopo aver creato la sezione TPU, devi installare JAX su tutti gli host della sezione TPU. Puoi farlo utilizzando il comando gcloud compute tpus tpu-vm ssh con i parametri --worker=all e --commamnd.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --worker=all \
  --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

Esecuzione del codice JAX sulla sezione

Per eseguire il codice JAX su una sezione TPU, devi eseguire il codice su ogni host nella sezione TPU. La chiamata jax.device_count() smette di rispondere finché non viene chiamata su ogni host nella sezione. L'esempio seguente mostra come eseguire un calcolo JAX su una sezione TPU.

Prepara il codice

È necessaria la versione gcloud >= 344.0.0 (per il comando scp). Utilizza gcloud --version per controllare la versione di gcloud ed esegui gcloud components upgrade, se necessario.

Crea un file denominato example.py con il seguente codice:


import jax

# Initialize the slice
jax.distributed.initialize()

# The total number of TPU cores in the slice
device_count = jax.device_count()

# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()

# The psum is performed over all mapped devices across the slice
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
    print('global device count:', jax.device_count())
    print('local device count:', jax.local_device_count())
    print('pmap result:', r)

Copia example.py in tutte le VM worker TPU nello slice

$ gcloud compute tpus tpu-vm scp ./example.py ${TPU_NAME}: \
  --worker=all \
  --zone=${ZONE} \
  --project=${PROJECT_ID}

Se non hai mai utilizzato il comando scp, potresti visualizzare un errore simile al seguente:

ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH
agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try
again.

Per risolvere l'errore, esegui il comando ssh-add come visualizzato nel messaggio di errore ed esegui di nuovo il comando.

Esegui il codice sulla sezione

Avvia il programma example.py su ogni VM:

$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --worker=all \
  --command="python3 ./example.py"

Output (prodotto con una sezione v5litepod-32):

global device count: 32
local device count: 4
pmap result: [32. 32. 32. 32.]

Esegui la pulizia

Al termine dell'utilizzo della VM TPU, segui questi passaggi per pulire le risorse.

  1. Elimina le risorse Cloud TPU e Compute Engine.

    $ gcloud compute tpus tpu-vm delete ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID}
  2. Verifica che le risorse siano state eliminate eseguendo gcloud compute tpus execution-groups list. L'eliminazione potrebbe richiedere diversi minuti. L'output del seguente comando non deve includere nessuna delle risorse create in questo tutorial:

    $ gcloud compute tpus tpu-vm list --zone=${ZONE} \
    --project=${PROJECT_ID}