Karena sifat JAX yang terdistribusi dengan Pathways, beberapa operasi mungkin tidak dapat diskalakan dengan baik karena overhead komunikasi. Meskipun Pathways meminimalkan overhead ini dengan fitur seperti pengiriman asinkron, ada beberapa hal yang perlu Anda ketahui saat mem-porting workload JAX ke Pathways atau menskalakan workload JAX dengan Pathways ke sejumlah besar akselerator.
Sebelum memulai
Pastikan Anda memiliki:
- Menginstal alat Kubernetes
- Menginstal gcloud CLI
- Mengaktifkan TPU API
- Mengaktifkan Google Kubernetes Engine API
Indeks proses
JAX dengan Pathways memperlakukan semua perangkat di seluruh cluster Pathways Anda sebagai lokal. Hal ini menyederhanakan pengelolaan perangkat dan memungkinkan JAX menggunakan semua resource yang tersedia. Dalam praktiknya, hal ini berarti:
jax.process_index()selalu 0 untuk semua perangkat.jax.devices()danjax.local_devices()menampilkan semua perangkat TPU di seluruh tugas.
Jenis hardware dan kolokasi
Untuk mendapatkan performa terbaik, tempatkan semua komponen Pathways dan tugas pengguna di zona
cloud yang Google Cloud sama. Gunakan CPU besar seperti proxy IFRT dan pengelola resource. Sebaiknya gunakan setidaknya n2-standard-64 khusus yang dilengkapi dengan 64 vCPU dan memori 256 GB.
PathwaysUtils
Pathways-utils adalah repositori GitHub berbasis Python yang menyediakan utilitas dan alat penting yang memungkinkan Anda menyederhanakan deployment dan eksekusi workload JAX pada arsitektur Pathways di Cloud. Paket ini menangani adaptasi yang diperlukan untuk lingkungan cloud, sehingga developer JAX dapat berfokus pada alur kerja machine learning inti mereka dengan konfigurasi khusus platform yang minimal. Secara khusus, paket ini menawarkan:
- Backend JAX "proxy": backend kustom ini memungkinkan aplikasi JAX Anda menggunakan infrastruktur Pathways dengan menetapkan variabel lingkungan
JAX_PLATFORMS=proxy. - Utilitas Pembuatan Profil Terintegrasi: kemampuan pembuatan profil yang memungkinkan Anda memahami performa aplikasi. Dengan menggunakan JAX profiling API standar seperti
jax.profiler.start_tracedanjax.profiler.start_server, Anda dapat membuat profil tidak hanya kode JAX, tetapi juga komponen Pathways yang mendasarinya, sehingga memberikan tampilan eksekusi yang holistik dalam lingkungan cloud. - Checkpointing Terdistribusi dengan Orbax: pengendali checkpoint Orbax kustom yang memungkinkan Anda menggunakan checkpoint terdistribusi dan memulihkan checkpoint saat menggunakan library Orbax dalam lingkungan Pathways. Integrasi ini bertujuan untuk berfungsi tanpa memerlukan perubahan apa pun pada kode checkpointing Orbax yang ada selama kode tersebut mengimpor
pathwaysutils. - Primitif Pelatihan Elastis: menyediakan primitif pelatihan elastis dasar yang dapat Anda gunakan untuk membuat alur kerja pelatihan yang kuat dan dapat diskalakan menggunakan Pathways. Primitif ini memungkinkan tugas pelatihan Anda beradaptasi secara dinamis terhadap perubahan resource yang tersedia, sehingga meningkatkan efisiensi dan ketahanan di lingkungan cloud.
Checkpointing
Orbax diuji secara menyeluruh dengan Pathways untuk
checkpointing dan pemulihan terdistribusi dengan Cloud Storage. Saat Anda
menetapkan variabel lingkungan ENABLE_PATHWAYS_PERSISTENCE=1 dan memanggil
import pathwaysutils; pathwaysutils.initialize() di train.py, ArrayHandler kustom akan didaftarkan yang menangani operasi checkpoint secara efisien
melalui proxy IFRT, sehingga pekerja Pathways di akselerator dapat langsung menyimpan dan memulihkan data.
Python yang ditempatkan bersama
Python yang ditempatkan bersama adalah JAX API open source yang memungkinkan Anda menjalankan kode Python yang ditentukan pengguna langsung di host TPU atau GPU, yang lebih mudah di multi-pengontrol JAX. Hal ini memungkinkan tugas yang lebih intensif komputasi, seperti pemuatan data dan checkpointing, untuk menghindari transfer data antara klien dan mesin TPU. Untuk mengonfigurasi cluster Pathways agar menjalankan Python JAX API yang ditempatkan bersama, ikuti petunjuk di README Python yang ditempatkan bersama. Petunjuk ini menjelaskan cara memulai sidecar Python yang ditempatkan bersama di samping pekerja Pathways Anda.
Pemuatan data
Selama pelatihan, kita berulang kali memuat batch dari set data untuk dimasukkan ke dalam model. Memiliki pemuat data asinkron yang efisien yang membagi batch di seluruh host penting untuk menghindari akselerator yang kekurangan pekerjaan. Saat menjalankan pelatihan dengan Pathways, pemuat data berjalan di VM CPU (tidak seperti VM TPU yang digunakan pada penyiapan multi-pengontrol) dan mengirimkan data ke VM TPU. Hal ini menyebabkan latensi yang lebih tinggi dalam membaca data, tetapi hal tersebut dikurangi sebagian dengan membaca X batch di host CPU dan mengirimkan data yang dibaca secara asinkron ke TPU. Solusi ini cukup saat berjalan dalam skala kecil hingga menengah.
Untuk performa optimal dalam skala besar, sebaiknya tempatkan bersama pipeline data input Anda dengan menggunakan Python yang ditempatkan bersama untuk menjalankan pipeline data Anda langsung di akselerator. Hal ini akan menghilangkan bottleneck CPU dan memanfaatkan interkoneksi cepat TPU untuk transfer data.
Anda dapat menemukan implementasi referensi untuk memigrasikan pipeline input berbasis TFDS
dalam implementasi RemoteIterator di
multihost_dataloading.py.
Implementasi ini berfungsi di JAX multi-pengontrol dan Pathways secara terdistribusi menggunakan Python JAX API yang ditempatkan bersama.
Pembuatan Versi Jax
Rilis Pathways sangat terkait dengan versi JAX untuk memastikan kompatibilitas dan stabilitas. Untuk menghindari potensi masalah, pastikan artefak Pathways dan versi JAX Anda selaras. Setiap rilis Pathways secara jelas menentukan
versi JAX yang kompatibel melalui tag dalam bentuk jax-<version>.
Cache Kompilasi
Cache kompilasi persisten Pathways adalah fitur yang memungkinkan server Pathways menyimpan file yang dapat dieksekusi XLA yang dikompilasi di lokasi persisten, seperti Cloud Storage, untuk menghindari kompilasi yang berlebihan. Fitur ini diaktifkan secara default. Lokasi cache diteruskan sebagai flag --gcs_scratch_location ke pengelola resource dan container pekerja Pathways. Untuk meminimalkan biaya penyimpanan terkait, cache melampirkan kebijakan siklus proses ke lokasi Cloud Storage. Ada batas 50 kebijakan per bucket Cloud Storage. Oleh karena itu, sebaiknya gunakan lokasi Cloud Storage umum di semua workload.
Cache ini mirip dengan cache kompilasi JAX
yang dinonaktifkan oleh pathwaysutils.initialize() untuk workload Pathways.
Izin Cloud Storage berikut diperlukan untuk cache kompilasi:
storage.buckets.get: Untuk mengambil metadata bucket.storage.buckets.update: Penting agar Pathways dapat menyiapkan kebijakan siklus proses objek untuk menerapkan TTL bagi penghapusan cache.storage.objects.list: Untuk mencantumkan objek cache yang ada dalam bucket.storage.objects.create: Untuk menulis file yang dapat dieksekusi yang dikompilasi baru ke cache.storage.objects.get: Untuk membaca file yang dapat dieksekusi yang di-cache dari bucket.
Pembuatan profil
Anda dapat menggunakan profiler JAX untuk membuat pelacakan program JAX. Ada dua cara umum yang didukung dengan Pathways:
- Terprogram
- Mengambil profil secara terprogram dari kode JAX Anda
- Manual
- Mengambil profil sesuai permintaan setelah memulai server profiler dari kode JAX Anda
Dalam kedua kasus tersebut, profil ditulis ke bucket Cloud Storage. Akan ada beberapa file pelacakan yang dibuat di bucket Cloud Storage, yang berpotensi berada di folder stempel waktu yang berbeda, misalnya:
- Proses Python utama yang memanggil pelacakan (biasanya VM notebook Anda):
<jax-client-vm-name>.xplane.pb - Proxy IFRT Pathways:
client.<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb - Pengelola resource Pathways:
server.*<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb - Pekerja Pathways:
server.*<tpu-node-name>.xplane.pb
File pelacakan ini dapat dianalisis dengan TensorBoard dengan menjalankan perintah berikut. Untuk mengetahui informasi selengkapnya tentang TensorBoard dan semua alat pembuatan profilnya, lihat Mengoptimalkan performa TensorFlow menggunakan Profiler.
# verify trace files are present gsutil ls -l -r gs://BUCKET/PREFIX # View on tensorboard tensorboard --logdir=gs://BUCKET/PREFIX
Ganti kode berikut:
BUCKET: bucket Cloud Storage untuk menyimpan file pelacakanPREFIX: jalur dalam bucket Cloud Storage Anda untuk menyimpan file pelacakan
Pengambilan profil terprogram
Ambil profil dari dalam kode Anda. Profil disimpan di dalam
gs://<bucket>/<prefix> di bawah direktori stempel waktu
import jax import pathwaysutils pathwaysutils.initialize() jax.profiler.start_trace("gs://BUCKET/PREFIX") # Run the operations to be profiled key = jax.random.key(0) x = jax.random.normal(key, (5000, 5000)) y = x @ x y.block_until_ready() jax.profiler.stop_trace()
Pengambilan profil manual
Untuk mengambil informasi profil secara manual, Anda harus memulai server profiler dari kode Python Anda:
import jax
import pathwaysutils
pathwaysutils.initialize()
jax.profiler.start_server(jax_profiler_port)
# Your JAX code
jax.profiler.stop_server() # this is functionally a no-op
Saat server profiler berjalan, Anda dapat mengambil profil dan mengekspor data ke lokasi Cloud Storage target:
export DURATION_IN_SECS=6
export OUTPUT_DIR=gs://<var>BUCKET</var>/<var>PREFIX</var>
curl -d "{\"duration_ms\":\"${DURATION_IN_SECS} * 1000 }}\", \"repository_path\":\"${OUTPUT_DIR}\"}" -H "Content-Type: application/json" -X POST http://localhost:<jax_profiler_port>
Anda dapat menemukan informasi waktu untuk metode klien proxy IFRT seperti Compile dan Execute dalam pelacakan program Anda. Peristiwa ini, yang menjelaskan interaksi dengan server gRPC Proxy IFRT selama kompilasi dan eksekusi, muncul di thread bernama GrpcClientSessionUserFuturesWorkQueue. Dengan memeriksa thread ini dalam pelacakan, Anda dapat memperoleh insight tentang performa operasi ini.
Flag XLA
Saat menggunakan Pathways, Anda harus menetapkan flag XLA di container pathways-proxy. Anda dapat melakukannya menggunakan XPK atau PathwaysJob API.
Saat menggunakan XPK, tetapkan flag XLA seperti berikut:
--custom-pathways-proxy-server-args="--xla_flag_1=value1 --xla_flag_2=value2"
Saat menggunakan PathwaysJob API, tetapkan flag XLA seperti berikut:
apiVersion: pathways-job.pathways.domain/v1 kind: PathwaysJob metadata: name: pathways-USER spec: customComponents: - componentType: proxy_server customFlags: - --xla_flag_1=value1 - --xla_flag_2=value2
Ganti kode berikut:
USER: nama pengguna Anda Google Cloudvalue[n]: flag XLA yang ingin Anda tetapkan
Dump HLO
Untuk mempelajari lebih lanjut input High Level Optimizer (HLO) yang diberikan ke compiler XLA, Anda dapat mengonfigurasi Pathways untuk membuang HLO ke lokasi Cloud Storage yang ditentukan sebagai berikut:
apiVersion: pathways-job.pathways.domain/v1 kind: PathwaysJob metadata: name: pathways-USER spec: customComponents: - componentType: proxy_server customEnv: - name: XLA_FLAGS value: "--xla_dump_to=gs://your-bucket/your-desired-prefix/"
Langkah berikutnya
- Membuat Cluster GKE dengan Pathways
- Inferensi multi-host dengan Pathways
- Workload batch dengan Pathways
- Mode interaktif Pathways
- Pelatihan yang tangguh dengan Pathways
- Memecahkan masalah Pathways