Menjalankan kode JAX di slice TPU
Sebelum menjalankan perintah dalam dokumen ini, pastikan Anda telah mengikuti petunjuk di Menyiapkan akun dan project Cloud TPU.
Setelah kode JAX Anda berjalan di satu papan TPU, Anda dapat menskalakan kode dengan menjalankannya di slice TPU. Slice TPU adalah beberapa papan TPU yang terhubung satu sama lain melalui koneksi jaringan khusus berkecepatan tinggi. Dokumen ini adalah pengantar untuk menjalankan kode JAX di slice TPU. Untuk mengetahui informasi yang lebih mendalam, lihat Menggunakan JAX di lingkungan multi-host dan multi-proses.
Peran yang diperlukan
Untuk mendapatkan izin yang Anda perlukan guna membuat TPU dan terhubung ke TPU menggunakan SSH, minta administrator Anda untuk memberi Anda peran IAM berikut di project Anda:
- Admin TPU (
roles/tpu.admin) - Pengguna Akun Layanan (
roles/iam.serviceAccountUser) - Compute Viewer (
roles/compute.viewer)
Untuk mengetahui informasi selengkapnya tentang pemberian peran, lihat Mengelola akses ke project, folder, dan organisasi.
Anda mungkin juga bisa mendapatkan izin yang diperlukan melalui peran khusus atau peran bawaan lainnya.
Membuat slice Cloud TPU
Buat beberapa variabel lingkungan:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=europe-west4-b export ACCELERATOR_TYPE=v5litepod-32 export RUNTIME_VERSION=v2-alpha-tpuv5-lite
Deskripsi variabel lingkungan
PROJECT_ID: ID project Anda. Google Cloud Gunakan project yang ada atau buat project baru.TPU_NAME: Nama TPU.ZONE: Zona tempat pembuatan VM TPU. Untuk mengetahui informasi selengkapnya tentang zona yang didukung, lihat Region dan zona TPU.ACCELERATOR_TYPE: Jenis akselerator menentukan versi dan ukuran Cloud TPU yang ingin Anda buat. Untuk mengetahui informasi selengkapnya tentang jenis akselerator yang didukung untuk setiap versi TPU, lihat Versi TPU.RUNTIME_VERSION: Versi software Cloud TPU.
Buat slice TPU menggunakan perintah
gcloud. Misalnya, untuk membuat slice v5litepod-32, gunakan perintah berikut:$ gcloud compute tpus tpu-vm create ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --accelerator-type=${ACCELERATOR_TYPE} \ --version=${RUNTIME_VERSION}
Menginstal JAX di slice
Setelah membuat slice TPU, Anda harus menginstal JAX di semua host dalam slice TPU. Anda dapat melakukannya menggunakan perintah gcloud compute tpus tpu-vm ssh menggunakan parameter --worker=all dan --commamnd.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --worker=all \ --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Menjalankan kode JAX di slice
Untuk menjalankan kode JAX di slice TPU, Anda harus menjalankan kode di setiap host dalam slice TPU. Panggilan jax.device_count() berhenti merespons hingga dipanggil di setiap host dalam slice. Contoh berikut mengilustrasikan cara menjalankan perhitungan JAX di slice TPU.
Menyiapkan kode
Anda memerlukan gcloud versi >= 344.0.0 (untuk perintah scp).
Gunakan gcloud --version untuk memeriksa versi gcloud, dan
jalankan gcloud components upgrade, jika diperlukan.
Buat file bernama example.py dengan kode berikut:
import jax
# Initialize the slice
jax.distributed.initialize()
# The total number of TPU cores in the slice
device_count = jax.device_count()
# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()
# The psum is performed over all mapped devices across the slice
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)
# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
print('global device count:', jax.device_count())
print('local device count:', jax.local_device_count())
print('pmap result:', r)
Menyalin example.py ke semua VM pekerja TPU di slice
$ gcloud compute tpus tpu-vm scp ./example.py ${TPU_NAME}: \ --worker=all \ --zone=${ZONE} \ --project=${PROJECT_ID}
Jika sebelumnya Anda belum pernah menggunakan perintah scp, Anda mungkin akan melihat error yang mirip dengan berikut:
ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH
agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try
again.
Untuk mengatasi error, jalankan perintah ssh-add seperti yang ditampilkan dalam pesan error dan jalankan kembali perintah tersebut.
Menjalankan kode di slice
Luncurkan program example.py di setiap VM:
$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --worker=all \ --command="python3 ./example.py"
Output (dihasilkan dengan slice v5litepod-32):
global device count: 32
local device count: 4
pmap result: [32. 32. 32. 32.]
Pembersihan
Setelah selesai menggunakan VM TPU, ikuti langkah-langkah berikut untuk membersihkan resource Anda.
Hapus resource Cloud TPU dan Compute Engine Anda.
$ gcloud compute tpus tpu-vm delete ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID}
Verifikasi bahwa resource telah dihapus dengan menjalankan
gcloud compute tpus execution-groups list. Penghapusan mungkin memerlukan waktu beberapa menit. Output dari perintah berikut tidak boleh menyertakan resource yang dibuat dalam tutorial ini:$ gcloud compute tpus tpu-vm list --zone=${ZONE} \ --project=${PROJECT_ID}