Karena sifat terdistribusi JAX dengan Pathways, beberapa operasi mungkin tidak menskalakan dengan baik karena overhead komunikasi. Meskipun Pathways meminimalkan overhead ini dengan fitur seperti pengiriman asinkron, ada beberapa hal yang perlu Anda ketahui saat mem-porting beban kerja JAX ke Pathways atau menskalakan beban kerja JAX dengan Pathways ke sejumlah besar akselerator.
Sebelum memulai
Pastikan Anda memiliki:
- Alat Kubernetes yang terinstal
- Menginstal gcloud CLI
- Mengaktifkan TPU API
- Mengaktifkan Google Kubernetes Engine API
Indeks proses
JAX dengan Pathways memperlakukan semua perangkat di seluruh cluster Pathways Anda sebagai lokal. Hal ini menyederhanakan pengelolaan perangkat dan memungkinkan JAX memanfaatkan semua resource yang tersedia. Dalam praktiknya, hal ini berarti:
jax.process_index()selalu 0 untuk semua perangkat.jax.devices()danjax.local_devices()menampilkan semua perangkat TPU di seluruh tugas.
Jenis hardware dan kolokasi
Untuk mendapatkan performa terbaik, tempatkan semua komponen Pathways dan tugas pengguna di
zona cloud Google Cloud yang sama. Gunakan CPU besar seperti proxy IFRT dan pengelola
resource. Sebaiknya gunakan n2-standard-64 khusus yang dilengkapi dengan 64 vCPU dan memori 256 GB.
PathwaysUtils
Pathways-utils adalah repositori GitHub berbasis Python yang menyediakan utilitas dan alat penting yang memungkinkan Anda menyederhanakan deployment dan eksekusi workload JAX pada arsitektur Pathways on Cloud. Paket ini menangani adaptasi yang diperlukan untuk lingkungan cloud, sehingga developer JAX dapat berfokus pada alur kerja machine learning inti mereka dengan konfigurasi khusus platform yang minimal. Secara khusus, fitur ini menawarkan:
- Backend JAX "proxy": backend kustom ini memungkinkan aplikasi JAX Anda menggunakan infrastruktur Pathways dengan menyetel variabel lingkungan
JAX_PLATFORMS=proxy. - Utilitas Pembuatan Profil Terintegrasi: kemampuan pembuatan profil yang memungkinkan Anda memahami performa aplikasi Anda. Dengan menggunakan API pembuatan profil JAX standar seperti
jax.profiler.start_tracedanjax.profiler.start_server, Anda dapat membuat profil tidak hanya kode JAX, tetapi juga komponen Pathways yang mendasarinya, sehingga memberikan gambaran menyeluruh tentang eksekusi dalam lingkungan cloud. - Checkpoint Terdistribusi dengan Orbax: handler checkpoint Orbax kustom yang memungkinkan Anda menggunakan checkpoint terdistribusi dan memulihkan checkpoint saat menggunakan library Orbax dalam lingkungan Pathways. Integrasi ini bertujuan untuk berfungsi tanpa memerlukan perubahan pada kode pembuatan checkpoint Orbax yang ada selama kode tersebut mengimpor
pathwaysutils. - Primitif Pelatihan Elastis: menyediakan primitif pelatihan elastis dasar yang dapat Anda gunakan untuk membangun alur kerja pelatihan yang andal dan skalabel menggunakan Pathways. Primitif ini memungkinkan tugas pelatihan Anda beradaptasi secara dinamis terhadap perubahan pada resource yang tersedia, sehingga meningkatkan efisiensi dan ketahanan di lingkungan cloud.
Melakukan checkpoint
Orbax diuji secara menyeluruh dengan Pathways untuk
pembuatan dan pemulihan checkpoint terdistribusi dengan Cloud Storage. Saat Anda
memanggil import pathwaysutils; pathwaysutils.initialize() di train.py, ArrayHandler kustom
didaftarkan yang secara efisien menangani operasi checkpoint
melalui proxy IFRT, sehingga memungkinkan pekerja Pathways di akselerator untuk menyimpan dan memulihkan data secara langsung.
Python yang ditempatkan bersama
Python yang Ditempatkan Bersama adalah JAX API open source yang memungkinkan Anda menjalankan kode Python yang ditentukan pengguna langsung di host TPU atau GPU, yang lebih mudah dalam JAX multi-pengontrol. Hal ini memungkinkan tugas yang lebih intensif komputasi, seperti pemuatan data dan pembuatan titik pemeriksaan, untuk menghindari transfer data antara klien dan mesin TPU. Untuk mengonfigurasi cluster Pathways agar menjalankan API JAX Python yang ditempatkan bersama, ikuti petunjuk di README Python yang ditempatkan bersama Petunjuk ini menjelaskan cara memulai sidecar Python yang ditempatkan bersama di samping pekerja Pathways Anda.
Pemuatan data
Selama pelatihan, kita berulang kali memuat batch dari set data untuk dimasukkan ke dalam model. Memiliki pemuat data asinkron yang efisien yang membagi batch di seluruh host penting untuk menghindari kekurangan tugas akselerator. Saat menjalankan pelatihan dengan Pathways, pemuat data berjalan di VM CPU (tidak seperti VM TPU yang digunakan pada penyiapan multi-pengontrol) dan mengirimkan data ke VM TPU. Hal ini menimbulkan latensi yang lebih tinggi dalam membaca data, tetapi hal itu sebagian dimitigasi dengan membaca X jumlah batch di host CPU dan mengirimkan data yang dibaca secara asinkron ke TPU. Solusi ini sudah cukup saat berjalan dalam skala kecil hingga sedang.
Untuk performa yang optimal dalam skala besar, sebaiknya lakukan penempatan bersama pipeline data input Anda dengan menggunakan Python yang ditempatkan bersama untuk menjalankan pipeline data Anda langsung di akselerator. Hal ini menghilangkan hambatan CPU dan memanfaatkan interkoneksi cepat TPU untuk transfer data.
Anda dapat menemukan implementasi referensi untuk memigrasikan pipeline input berbasis TFDS di implementasi RemoteIterator dalam
multihost_dataloading.py.
Implementasi ini berfungsi di JAX dan Pathways multi-pengontrol secara terdistribusi menggunakan colocated Python JAX API.
Pembuatan Versi Jax
Rilis Pathways sangat terkait dengan versi JAX untuk memastikan kompatibilitas dan stabilitas. Untuk menghindari potensi masalah, pastikan artefak Pathways dan versi JAX Anda selaras. Setiap rilis Pathways secara jelas menentukan
versi JAX yang kompatibel melalui tag dalam bentuk jax-<version>.
Cache Kompilasi
Cache kompilasi persisten Pathways adalah fitur yang memungkinkan server Pathways menyimpan executable XLA yang dikompilasi di lokasi persisten, seperti Cloud Storage, untuk menghindari kompilasi yang berlebihan. Fitur ini diaktifkan secara
default. Lokasi cache diteruskan sebagai tanda --gcs_scratch_location
ke container pekerja Pathways dan pengelola resource. Untuk meminimalkan biaya penyimpanan terkait, cache melampirkan kebijakan siklus proses ke lokasi Cloud Storage. Ada batas 50 kebijakan per bucket Cloud Storage. Oleh karena itu, sebaiknya gunakan lokasi Cloud Storage umum di semua beban kerja.
Cache ini mirip dengan cache kompilasi JAX
yang dinonaktifkan oleh pathwaysutils.initialize() untuk workload Pathways.
Pembuatan profil
Anda dapat menggunakan profiler JAX untuk membuat rekaman aktivitas program JAX. Ada dua cara umum yang didukung dengan Pathways:
- Terprogram
- Mengambil profil secara terprogram dari kode JAX Anda
- Manual
- Merekam profil sesuai permintaan setelah memulai server profiler dari kode JAX Anda
Dalam kedua kasus tersebut, profil ditulis ke bucket Cloud Storage. Akan ada beberapa file rekaman aktivitas yang dibuat di bucket Cloud Storage yang mungkin berada di folder stempel waktu yang berbeda, misalnya:
- Proses Python utama yang memanggil rekaman aktivitas (biasanya VM notebook Anda):
<jax-client-vm-name>.xplane.pb - Proxy IFRT Pathways:
client.<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb - Pengelola Resource Pathways:
server.*<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb - Pekerja Pathways:
server.*<tpu-node-name>.xplane.pb
File rekaman aktivitas ini dapat dianalisis dengan TensorBoard dengan menjalankan perintah berikut. Untuk mengetahui informasi selengkapnya tentang TensorBoard dan semua alat profiling-nya, lihat Mengoptimalkan performa TensorFlow menggunakan Profiler.
# verify trace files are present gsutil ls -l -r gs://BUCKET/PREFIX # View on tensorboard tensorboard --logdir=gs://BUCKET/PREFIX
Ganti kode berikut:
BUCKET: bucket Cloud Storage untuk menyimpan file rekaman aktivitasPREFIX: jalur dalam bucket Cloud Storage Anda untuk menyimpan file rekaman aktivitas
Pengambilan profil terprogram
Merekam profil dari dalam kode Anda. Profil disimpan di dalam
gs://<bucket>/<prefix> di bawah direktori stempel waktu
import jax import pathwaysutils pathwaysutils.initialize() jax.profiler.start_trace("gs://BUCKET/PREFIX") # Run the operations to be profiled key = jax.random.key(0) x = jax.random.normal(key, (5000, 5000)) y = x @ x y.block_until_ready() jax.profiler.stop_trace()
Pengambilan profil manual
Untuk merekam informasi profil secara manual, Anda harus memulai server profiler dari kode Python Anda:
import jax
import pathwaysutils
pathwaysutils.initialize()
jax.profiler.start_server(jax_profiler_port)
# Your JAX code
jax.profiler.stop_server() # this is functinoally a no-op
Saat server profiler berjalan, Anda dapat merekam profil dan mengekspor data ke lokasi Cloud Storage target:
export DURATION_IN_SECS=6
export OUTPUT_DIR=gs://<var>BUCKET</var>/<var>PREFIX</var>
curl -d "{\"duration_ms\":\"${DURATION_IN_SECS} * 1000 }}\", \"repository_path\":\"${OUTPUT_DIR}\"}" -H "Content-Type: application/json" -X POST http://localhost:<jax_profiler_port
Anda dapat menemukan informasi pengaturan waktu untuk metode klien proxy IFRT seperti Compile dan
Execute dalam rekaman aktivitas program Anda. Peristiwa ini, yang menjelaskan interaksi dengan server gRPC IFRT Proxy selama kompilasi dan eksekusi, muncul di thread bernama GrpcClientSessionUserFuturesWorkQueue. Dengan memeriksa
thread ini dalam rekaman aktivitas, Anda dapat memperoleh insight tentang performa
operasi ini.
Flag XLA
Saat menggunakan Pathways, Anda perlu menyetel flag XLA di container pathways-proxy. Anda dapat melakukannya menggunakan XPK atau PathwaysJob API.
Saat menggunakan XPK, tetapkan flag XLA seperti berikut:
--custom-pathways-proxy-server-args="--xla_flag_1=value1 --xla_flag_2=value2"
Saat menggunakan PathwaysJob API, tetapkan flag XLA seperti berikut:
apiVersion: pathways-job.pathways.domain/v1 kind: PathwaysJob metadata: name: pathways-USER spec: customComponents: - componentType: proxy_server customFlags: - --xla_flag_1=value1 - --xla_flag_2=value2
Ganti kode berikut:
USER: nama pengguna Google Cloud Andavalue[n]: tanda XLA yang ingin Anda tetapkan
Dump HLO
Untuk mempelajari lebih dalam input Pengoptimal Tingkat Tinggi (HLO) yang diberikan ke kompiler XLA, Anda dapat mengonfigurasi Pathways untuk mengekspor HLO ke lokasi Cloud Storage yang ditentukan sebagai berikut:
apiVersion: pathways-job.pathways.domain/v1 kind: PathwaysJob metadata: name: pathways-USER spec: customComponents: - componentType: proxy_server customEnv: - name: XLA_FLAGS value: "--xla_dump_to=gs://your-gcs-bucket/your-desired-prefix/"
Langkah berikutnya
- Membuat Cluster GKE dengan Pathways
- Inferensi multihost dengan Pathways
- Batch workload dengan Pathways
- Mode interaktif jalur
- Pelatihan yang tangguh dengan Pathways
- Pemecahan Masalah Jalur