Esegui un calcolo su una VM Cloud TPU utilizzando JAX
Questo documento fornisce una breve introduzione all'utilizzo di JAX e Cloud TPU.
Prima di iniziare
Prima di eseguire i comandi in questo documento, devi creare un account Google Cloud, installare Google Cloud CLI e configurare il comando gcloud. Per
maggiori informazioni, consulta Configurare l'ambiente Cloud TPU.
Ruoli obbligatori
Per ottenere le autorizzazioni necessarie per creare una TPU e connetterti a quest'ultima tramite SSH, chiedi all'amministratore di concederti i seguenti ruoli IAM sul progetto:
-
TPU Admin (
roles/tpu.admin) -
Utente Service Account (
roles/iam.serviceAccountUser) -
Compute Viewer (
roles/compute.viewer)
Per saperne di più sulla concessione dei ruoli, consulta Gestisci l'accesso a progetti, cartelle e organizzazioni.
Potresti anche riuscire a ottenere le autorizzazioni richieste tramite i ruoli personalizzati o altri ruoli predefiniti.
Crea una VM Cloud TPU utilizzando gcloud
Definisci alcune variabili di ambiente per semplificare l'utilizzo dei comandi.
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-east5-a export ACCELERATOR_TYPE=v5litepod-8 export RUNTIME_VERSION=v2-alpha-tpuv5-lite
Descrizioni delle variabili di ambiente
Variabile Descrizione PROJECT_IDL'ID progetto Google Cloud . Utilizza un progetto esistente o creane uno nuovo. TPU_NAMEIl nome della TPU. ZONELa zona in cui creare la VM TPU. Per saperne di più sulle zone supportate, consulta Regioni e zone TPU. ACCELERATOR_TYPEIl tipo di acceleratore specifica la versione e le dimensioni della Cloud TPU che vuoi creare. Per saperne di più sui tipi di acceleratore supportati per ogni versione di TPU, consulta la sezione Versioni di TPU. RUNTIME_VERSIONLa versione software di Cloud TPU. Crea la VM TPU eseguendo questo comando da Cloud Shell o dal terminale del computer in cui è installata Google Cloud CLI.
$ gcloud compute tpus tpu-vm create $TPU_NAME \ --project=$PROJECT_ID \ --zone=$ZONE \ --accelerator-type=$ACCELERATOR_TYPE \ --version=$RUNTIME_VERSION
Connettiti alla VM Cloud TPU
Connettiti alla VM TPU tramite SSH utilizzando il seguente comando:
$ gcloud compute tpus tpu-vm ssh $TPU_NAME \ --project=$PROJECT_ID \ --zone=$ZONE
Se non riesci a connetterti a una VM TPU tramite SSH, il problema potrebbe essere che la VM TPU non ha un indirizzo IP esterno. Per accedere a una VM TPU senza un indirizzo IP esterno, segui le istruzioni riportate in Connettiti a una VM TPU senza un indirizzo IP pubblico.
Installa JAX sulla VM Cloud TPU
(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Controllo del sistema
Verifica che JAX possa accedere alla TPU ed eseguire operazioni di base:
Avvia l'interprete Python 3:
(vm)$ python3>>> import jax
Visualizza il numero di core TPU disponibili:
>>> jax.device_count()
Viene visualizzato il numero di core TPU. Il numero di core visualizzati dipende dalla versione della TPU che stai utilizzando. Per saperne di più, consulta Versioni TPU.
Eseguire un calcolo
>>> jax.numpy.add(1, 1)
Viene visualizzato il risultato dell'aggiunta di NumPy:
Output del comando:
Array(2, dtype=int32, weak_type=True)
Esci dall'interprete Python
>>> exit()
Esecuzione del codice JAX su una VM TPU
Ora puoi eseguire qualsiasi codice JAX. Gli esempi di Flax sono un ottimo punto di partenza per l'esecuzione di modelli di machine learning standard in JAX. Ad esempio, per addestrare una rete convoluzionale MNIST di base:
Installa le dipendenze degli esempi di Flax:
(vm)$ pip install --upgrade clu (vm)$ pip install tensorflow (vm)$ pip install tensorflow_datasets
Installa Flax:
(vm)$ git clone https://github.com/google/flax.git (vm)$ pip install --user flax
Esegui lo script di addestramento Flax MNIST:
(vm)$ cd flax/examples/mnist (vm)$ python3 main.py --workdir=/tmp/mnist \ --config=configs/default.py \ --config.learning_rate=0.05 \ --config.num_epochs=5
Lo script scarica il set di dati e inizia l'addestramento. L'output dello script dovrebbe essere simile al seguente:
I0214 18:00:50.660087 140369022753856 train.py:146] epoch: 1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88 I0214 18:00:52.015867 140369022753856 train.py:146] epoch: 2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72 I0214 18:00:53.377511 140369022753856 train.py:146] epoch: 3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04 I0214 18:00:54.727168 140369022753856 train.py:146] epoch: 4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15 I0214 18:00:56.082807 140369022753856 train.py:146] epoch: 5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18
Esegui la pulizia
Per evitare che al tuo account Google Cloud vengano addebitati costi relativi alle risorse utilizzate in questa pagina, segui questi passaggi.
Al termine dell'utilizzo della VM TPU, segui questi passaggi per pulire le risorse.
Disconnettiti dall'istanza Cloud TPU, se non l'hai ancora fatto:
(vm)$ exit
Il prompt dovrebbe ora essere username@projectname, a indicare che ti trovi in Cloud Shell.
Elimina la tua Cloud TPU:
$ gcloud compute tpus tpu-vm delete $TPU_NAME \ --project=$PROJECT_ID \ --zone=$ZONE
Verifica che le risorse siano state eliminate eseguendo il seguente comando. Assicurati che la TPU non sia più elencata. L'eliminazione può richiedere qualche minuto.
$ gcloud compute tpus tpu-vm list \ --zone=$ZONE
Note sul rendimento
Ecco alcuni dettagli importanti particolarmente rilevanti per l'utilizzo delle TPU in JAX.
Spaziatura interna
Una delle cause più comuni di prestazioni lente sulle TPU è l'introduzione di padding involontario:
- Gli array in Cloud TPU sono suddivisi in riquadri. Ciò comporta l'aggiunta di spazio interno a una delle dimensioni in modo che sia un multiplo di 8 e a un'altra dimensione in modo che sia un multiplo di 128.
- L'unità di moltiplicazione matriciale funziona meglio con coppie di matrici di grandi dimensioni che riducono al minimo la necessità di padding.
Tipo di dati bfloat16
Per impostazione predefinita, la moltiplicazione di matrici in JAX sulle TPU utilizza bfloat16
con accumulo float32. Questo può essere controllato con l'argomento di precisione nelle chiamate di funzione jax.numpy pertinenti (matmul, dot, einsum e così via). In particolare:
precision=jax.lax.Precision.DEFAULT: utilizza la precisione mista bfloat16 (più veloce)precision=jax.lax.Precision.HIGH: utilizza più passaggi MXU per ottenere una maggiore precisioneprecision=jax.lax.Precision.HIGHEST: utilizza ancora più passaggi MXU per ottenere una precisione float32 completa
JAX aggiunge anche il tipo di dati bfloat16, che puoi utilizzare per eseguire il cast esplicito degli array in
bfloat16. Ad esempio,
jax.numpy.array(x, dtype=jax.numpy.bfloat16).
Passaggi successivi
Per ulteriori informazioni su Cloud TPU, vedi: