Menjalankan penghitungan di VM Cloud TPU menggunakan JAX

Dokumen ini memberikan pengantar singkat tentang cara menggunakan JAX dan Cloud TPU.

Sebelum memulai

Sebelum menjalankan perintah dalam dokumen ini, Anda harus membuat akun Google Cloud, menginstal Google Cloud CLI, dan mengonfigurasi perintah gcloud. Untuk mengetahui informasi selengkapnya, lihat Menyiapkan lingkungan Cloud TPU.

Peran yang diperlukan

Agar mendapatkan izin yang diperlukan untuk membuat TPU dan terhubung ke TPU menggunakan SSH, minta administrator untuk memberi Anda peran IAM berikut di project Anda:

Untuk mengetahui informasi selengkapnya tentang pemberian peran, lihat Mengelola akses ke project, folder, dan organisasi.

Anda mungkin juga bisa mendapatkan izin yang diperlukan melalui peran khusus atau peran bawaan lainnya.

Buat VM Cloud TPU menggunakan gcloud

  1. Tentukan beberapa variabel lingkungan agar perintah lebih mudah digunakan.

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-east5-a
    export ACCELERATOR_TYPE=v5litepod-8
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite

    Deskripsi variabel lingkungan

    Variabel Deskripsi
    PROJECT_ID ID project Google Cloud Anda. Gunakan project yang ada atau buat project baru.
    TPU_NAME Nama TPU.
    ZONE Zona tempat VM TPU akan dibuat. Untuk mengetahui informasi selengkapnya tentang zona yang didukung, lihat Region dan zona TPU.
    ACCELERATOR_TYPE Jenis akselerator menentukan versi dan ukuran Cloud TPU yang ingin Anda buat. Untuk mengetahui informasi selengkapnya tentang jenis akselerator yang didukung untuk setiap versi TPU, lihat Versi TPU.
    RUNTIME_VERSION Versi software Cloud TPU.

  2. Buat VM TPU dengan menjalankan perintah berikut dari Cloud Shell atau terminal komputer tempat Google Cloud CLI diinstal.

    $ gcloud compute tpus tpu-vm create $TPU_NAME \
        --project=$PROJECT_ID \
        --zone=$ZONE \
        --accelerator-type=$ACCELERATOR_TYPE \
        --version=$RUNTIME_VERSION

Menghubungkan ke VM Cloud TPU

Hubungkan ke VM TPU Anda melalui SSH menggunakan perintah berikut:

$ gcloud compute tpus tpu-vm ssh $TPU_NAME \
    --project=$PROJECT_ID \
    --zone=$ZONE

Jika Anda gagal terhubung ke VM TPU menggunakan SSH, hal ini mungkin karena VM TPU tidak memiliki alamat IP eksternal. Untuk mengakses VM TPU tanpa alamat IP eksternal, ikuti petunjuk di Menghubungkan ke VM TPU tanpa alamat IP publik.

Menginstal JAX di VM Cloud TPU

(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Pemeriksaan sistem

Verifikasi bahwa JAX dapat mengakses TPU dan dapat menjalankan operasi dasar:

  1. Mulai interpreter Python 3:

    (vm)$ python3
    >>> import jax
  2. Tampilkan jumlah core TPU yang tersedia:

    >>> jax.device_count()

Jumlah core TPU ditampilkan. Jumlah core yang ditampilkan bergantung pada versi TPU yang Anda gunakan. Untuk mengetahui informasi selengkapnya, lihat versi TPU.

Melakukan perhitungan

>>> jax.numpy.add(1, 1)

Hasil penambahan numpy ditampilkan:

Output dari perintah:

Array(2, dtype=int32, weak_type=True)

Keluar dari interpreter Python

>>> exit()

Menjalankan kode JAX di VM TPU

Sekarang Anda dapat menjalankan kode JAX apa pun yang Anda inginkan. Contoh Flax adalah tempat yang tepat untuk mulai menjalankan model ML standar di JAX. Misalnya, untuk melatih jaringan konvolusional MNIST dasar:

  1. Instal dependensi contoh Flax:

    (vm)$ pip install --upgrade clu
    (vm)$ pip install tensorflow
    (vm)$ pip install tensorflow_datasets
  2. Instal Flax:

    (vm)$ git clone https://github.com/google/flax.git
    (vm)$ pip install --user flax
  3. Jalankan skrip pelatihan Flax MNIST:

    (vm)$ cd flax/examples/mnist
    (vm)$ python3 main.py --workdir=/tmp/mnist \
        --config=configs/default.py \
        --config.learning_rate=0.05 \
        --config.num_epochs=5

Skrip akan mendownload set data dan memulai pelatihan. Output skrip akan terlihat seperti ini:

I0214 18:00:50.660087 140369022753856 train.py:146] epoch:  1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88
I0214 18:00:52.015867 140369022753856 train.py:146] epoch:  2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72
I0214 18:00:53.377511 140369022753856 train.py:146] epoch:  3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04
I0214 18:00:54.727168 140369022753856 train.py:146] epoch:  4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15
I0214 18:00:56.082807 140369022753856 train.py:146] epoch:  5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18

Pembersihan

Agar akun Google Cloud Anda tidak dikenai biaya untuk resource yang digunakan pada halaman ini, ikuti langkah-langkah berikut.

Setelah selesai menggunakan VM TPU, ikuti langkah-langkah berikut untuk membersihkan resource Anda.

  1. Putuskan koneksi dari instance Cloud TPU, jika Anda belum melakukannya:

    (vm)$ exit

    Kini perintah Anda akan menjadi username@projectname, yang menunjukkan Anda berada dalam Cloud Shell.

  2. Hapus Cloud TPU Anda:

    $ gcloud compute tpus tpu-vm delete $TPU_NAME \
        --project=$PROJECT_ID \
        --zone=$ZONE
  3. Pastikan resource telah dihapus dengan menjalankan perintah berikut. Pastikan TPU Anda tidak lagi tercantum. Proses penghapusan mungkin memerlukan waktu beberapa menit.

    $ gcloud compute tpus tpu-vm list \
        --zone=$ZONE

Catatan performa

Berikut beberapa detail penting yang sangat relevan untuk menggunakan TPU di JAX.

Padding

Salah satu penyebab paling umum performa lambat di TPU adalah pengenalan padding yang tidak disengaja:

  • Array di Cloud TPU disusun secara berpetak. Hal ini memerlukan pengisihan salah satu dimensi ke kelipatan 8, dan dimensi lain ke kelipatan 128.
  • Unit perkalian matriks berfungsi paling baik dengan pasangan matriks besar yang meminimalkan kebutuhan untuk padding.

dtype bfloat16

Secara default, perkalian matriks di JAX pada TPU menggunakan bfloat16 dengan akumulasi float32. Hal ini dapat dikontrol dengan argumen presisi pada panggilan fungsi jax.numpy yang relevan (matmul, dot, einsum, dll.). Pada khususnya:

  • precision=jax.lax.Precision.DEFAULT: menggunakan presisi bfloat16 campuran (tercepat)
  • precision=jax.lax.Precision.HIGH: menggunakan beberapa lintasan MXU untuk mencapai presisi yang lebih tinggi
  • precision=jax.lax.Precision.HIGHEST: menggunakan lebih banyak lagi iterasi MXU untuk mencapai presisi float32 penuh

JAX juga menambahkan dtype bfloat16, yang dapat Anda gunakan untuk secara eksplisit melakukan transmisi array ke bfloat16. Contoh, jax.numpy.array(x, dtype=jax.numpy.bfloat16).

Langkah berikutnya

Untuk mengetahui informasi selengkapnya tentang Cloud TPU, lihat: