Berechnung mit JAX auf einer Cloud TPU-VM ausführen
Dieses Dokument bietet eine kurze Einführung in die Arbeit mit JAX und Cloud TPU beschrieben.
Hinweis
Bevor Sie die Befehle in diesem Dokument ausführen, müssen Sie ein Google Cloud-Konto erstellen, die Google Cloud CLI installieren und den gcloud
-Befehl konfigurieren. Weitere Informationen finden Sie unter Cloud TPU-Umgebung einrichten.
Cloud TPU-VM mit gcloud
erstellen
Definieren Sie einige Umgebungsvariablen, um die Verwendung von Befehlen zu erleichtern.
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-east5-a export ACCELERATOR_TYPE=v5litepod-8 export RUNTIME_VERSION=v2-alpha-tpuv5-lite
Beschreibungen von Umgebungsvariablen
Variable Beschreibung PROJECT_ID
Ihre Google Cloud Projekt-ID. Verwenden Sie ein vorhandenes Projekt oder erstellen Sie ein neues Projekt. TPU_NAME
Der Name der TPU ZONE
Die Zone, in der die TPU-VM erstellt werden soll. Weitere Informationen zu unterstützten Zonen finden Sie unter TPU-Regionen und ‑Zonen. ACCELERATOR_TYPE
Der Beschleunigertyp gibt die Version und Größe der Cloud TPU an, die Sie erstellen möchten. Weitere Informationen zu den unterstützten Beschleunigertypen für die einzelnen TPU-Versionen finden Sie unter TPU-Versionen. RUNTIME_VERSION
Die Softwareversion der Cloud TPU. Erstellen Sie Ihre TPU-VM, indem Sie den folgenden Befehl in einer Cloud Shell oder Ihrem Computerterminal ausführen, in dem die Google Cloud CLI installiert ist.
$ gcloud compute tpus tpu-vm create $TPU_NAME \ --project=$PROJECT_ID \ --zone=$ZONE \ --accelerator-type=$ACCELERATOR_TYPE \ --version=$RUNTIME_VERSION
Verbindung zur Cloud TPU-VM herstellen
Stellen Sie mit dem folgenden Befehl eine SSH-Verbindung zu Ihrer TPU-VM her:
$ gcloud compute tpus tpu-vm ssh $TPU_NAME \ --project=$PROJECT_ID \ --zone=$ZONE
Wenn Sie keine SSH-Verbindung zu einer TPU-VM herstellen können, liegt das möglicherweise daran, dass die TPU-VM keine externe IP-Adresse hat. Wenn Sie ohne externe IP-Adresse auf eine TPU-VM zugreifen möchten, folgen Sie der Anleitung unter Verbindung zu einer TPU-VM ohne öffentliche IP-Adresse herstellen.
JAX auf der Cloud TPU-VM installieren
(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Systemprüfung
Prüfen Sie, ob JAX auf die TPU zugreifen und grundlegende Vorgänge ausführen kann:
Starten Sie den Python 3-Interpreter:
(vm)$ python3
>>> import jax
Prüfen Sie die Anzahl der verfügbaren TPU-Kerne:
>>> jax.device_count()
Die Anzahl der TPU-Kerne wird angezeigt. Die angezeigte Anzahl der Kerne hängt von der verwendeten TPU-Version ab. Weitere Informationen finden Sie unter TPU-Versionen.
Berechnung durchführen
>>> jax.numpy.add(1, 1)
Das Ergebnis von „numpy.add“ wird angezeigt.
Dia Befehlsausgabe lautet:
Array(2, dtype=int32, weak_type=True)
Python-Interpreter beenden
>>> exit()
JAX-Code auf einer TPU-VM ausführen
Sie können jetzt beliebigen JAX-Code ausführen. Die Flax-Beispiele sind ein guter Ausgangspunkt, um Standard-ML-Modelle in JAX auszuführen. So trainieren Sie beispielsweise ein einfaches MNIST-Convolutional-Network:
Installieren Sie die Abhängigkeiten für Flax-Beispiele:
(vm)$ pip install --upgrade clu (vm)$ pip install tensorflow (vm)$ pip install tensorflow_datasets
Installieren Sie FLAX:
(vm)$ git clone https://github.com/google/flax.git (vm)$ pip install --user flax
Führen Sie das Flax-MNIST-Trainings-Script aus:
(vm)$ cd flax/examples/mnist (vm)$ python3 main.py --workdir=/tmp/mnist \ --config=configs/default.py \ --config.learning_rate=0.05 \ --config.num_epochs=5
Das Script lädt das Dataset herunter und startet das Training. Die Ausgabe des Scripts sollte in etwa so aussehen:
I0214 18:00:50.660087 140369022753856 train.py:146] epoch: 1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88 I0214 18:00:52.015867 140369022753856 train.py:146] epoch: 2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72 I0214 18:00:53.377511 140369022753856 train.py:146] epoch: 3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04 I0214 18:00:54.727168 140369022753856 train.py:146] epoch: 4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15 I0214 18:00:56.082807 140369022753856 train.py:146] epoch: 5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18
Bereinigen
Mit den folgenden Schritten vermeiden Sie, dass Ihrem Google Cloud -Konto die auf dieser Seite verwendeten Ressourcen in Rechnung gestellt werden:
Wenn Sie mit Ihrer TPU-VM fertig sind, führen Sie die folgenden Schritte aus, um Ihre Ressourcen zu bereinigen.
Trennen Sie die Verbindung zur Cloud TPU-Instanz, sofern noch nicht geschehen:
(vm)$ exit
In der Eingabeaufforderung sollte nun Nutzername@Projektname angezeigt werden, was bedeutet, dass Sie sich in der Cloud Shell befinden.
Löschen Sie Ihre Cloud TPU:
$ gcloud compute tpus tpu-vm delete $TPU_NAME \ --project=$PROJECT_ID \ --zone=$ZONE
Überprüfen Sie mit dem folgenden Befehl, ob die Ressourcen gelöscht wurden. Achten Sie darauf, dass Ihre TPU nicht mehr aufgeführt wird. Der Löschvorgang kann einige Minuten dauern.
$ gcloud compute tpus tpu-vm list \ --zone=$ZONE
Hinweise zur Leistung
Im Folgenden finden Sie einige wichtige Details, die insbesondere für die Verwendung von TPUs in JAX relevant sind.
Padding
Eine der häufigsten Ursachen für eine langsame Ausführung auf TPUs ist versehentliches Padding:
- Arrays in der Cloud TPU sind gekachelt. Dies bedeutet, dass eine der Dimensionen auf ein Vielfaches von 8 und eine andere auf ein Vielfaches von 128 aufgefüllt wird.
- Die Matrixmultiplikationseinheit (Matrix Multiplication Unit, MXU) funktioniert am besten mit Paaren großer Matrizen, die die Notwendigkeit von Padding minimieren.
bfloat16-dtype
Standardmäßig verwendet die Matrixmultiplikation in JAX auf TPUs bfloat16 mit float32-Akkumulation. Dies kann mit dem Precision-Argument für relevante jax.numpy
-Funktionsaufrufe (matmul, dot, einsum usw.) gesteuert werden. Beispiele:
precision=jax.lax.Precision.DEFAULT
: verwendet die gemischte bfloat16-Precision (am schnellsten)precision=jax.lax.Precision.HIGH
: verwendet mehrere MXU-Durchläufe, um eine höhere Precision zu erreichenprecision=jax.lax.Precision.HIGHEST
: verwendet noch mehr MXU-Durchläufe, um eine vollständige float32-Precision zu erreichen
JAX fügt außerdem den bfloat16-dtype hinzu, mit dem Sie Arrays explizit in bfloat16
umwandeln können. Beispiel: jax.numpy.array(x, dtype=jax.numpy.bfloat16)
Nächste Schritte
Weitere Informationen zu Cloud TPU finden Sie hier: