Mengirimkan workload JAX ke Pathways

Karena sifat terdistribusi JAX dengan Pathways, beberapa operasi mungkin tidak menskalakan dengan baik karena overhead komunikasi. Meskipun Pathways meminimalkan overhead ini dengan fitur seperti pengiriman asinkron, ada beberapa hal yang perlu Anda ketahui saat mem-porting beban kerja JAX ke Pathways atau menskalakan beban kerja JAX dengan Pathways ke sejumlah besar akselerator.

Sebelum memulai

Pastikan Anda memiliki:

Indeks proses

JAX dengan Pathways memperlakukan semua perangkat di seluruh cluster Pathways Anda sebagai lokal. Hal ini menyederhanakan pengelolaan perangkat dan memungkinkan JAX memanfaatkan semua resource yang tersedia. Dalam praktiknya, hal ini berarti:

  • jax.process_index() selalu 0 untuk semua perangkat.
  • jax.devices() dan jax.local_devices() menampilkan semua perangkat TPU di seluruh tugas.

Jenis hardware dan kolokasi

Untuk mendapatkan performa terbaik, tempatkan semua komponen Pathways dan tugas pengguna di zona cloud Google Cloud yang sama. Gunakan CPU besar seperti proxy IFRT dan pengelola resource. Sebaiknya gunakan n2-standard-64 khusus yang dilengkapi dengan 64 vCPU dan memori 256 GB.

PathwaysUtils

Pathways-utils adalah repositori GitHub berbasis Python yang menyediakan utilitas dan alat penting yang memungkinkan Anda menyederhanakan deployment dan eksekusi workload JAX pada arsitektur Pathways on Cloud. Paket ini menangani adaptasi yang diperlukan untuk lingkungan cloud, sehingga developer JAX dapat berfokus pada alur kerja machine learning inti mereka dengan konfigurasi khusus platform yang minimal. Secara khusus, fitur ini menawarkan:

  • Backend JAX "proxy": backend kustom ini memungkinkan aplikasi JAX Anda menggunakan infrastruktur Pathways dengan menyetel variabel lingkungan JAX_PLATFORMS=proxy.
  • Utilitas Pembuatan Profil Terintegrasi: kemampuan pembuatan profil yang memungkinkan Anda memahami performa aplikasi Anda. Dengan menggunakan API pembuatan profil JAX standar seperti jax.profiler.start_trace dan jax.profiler.start_server, Anda dapat membuat profil tidak hanya kode JAX, tetapi juga komponen Pathways yang mendasarinya, sehingga memberikan gambaran menyeluruh tentang eksekusi dalam lingkungan cloud.
  • Checkpoint Terdistribusi dengan Orbax: handler checkpoint Orbax kustom yang memungkinkan Anda menggunakan checkpoint terdistribusi dan memulihkan checkpoint saat menggunakan library Orbax dalam lingkungan Pathways. Integrasi ini bertujuan untuk berfungsi tanpa memerlukan perubahan pada kode pembuatan checkpoint Orbax yang ada selama kode tersebut mengimpor pathwaysutils.
  • Primitif Pelatihan Elastis: menyediakan primitif pelatihan elastis dasar yang dapat Anda gunakan untuk membangun alur kerja pelatihan yang andal dan skalabel menggunakan Pathways. Primitif ini memungkinkan tugas pelatihan Anda beradaptasi secara dinamis terhadap perubahan pada resource yang tersedia, sehingga meningkatkan efisiensi dan ketahanan di lingkungan cloud.

Melakukan checkpoint

Orbax diuji secara menyeluruh dengan Pathways untuk pembuatan dan pemulihan checkpoint terdistribusi dengan Cloud Storage. Saat Anda memanggil import pathwaysutils; pathwaysutils.initialize() di train.py, ArrayHandler kustom didaftarkan yang secara efisien menangani operasi checkpoint melalui proxy IFRT, sehingga memungkinkan pekerja Pathways di akselerator untuk menyimpan dan memulihkan data secara langsung.

Python yang ditempatkan bersama

Python yang Ditempatkan Bersama adalah JAX API open source yang memungkinkan Anda menjalankan kode Python yang ditentukan pengguna langsung di host TPU atau GPU, yang lebih mudah dalam JAX multi-pengontrol. Hal ini memungkinkan tugas yang lebih intensif komputasi, seperti pemuatan data dan pembuatan titik pemeriksaan, untuk menghindari transfer data antara klien dan mesin TPU. Untuk mengonfigurasi cluster Pathways agar menjalankan API JAX Python yang ditempatkan bersama, ikuti petunjuk di README Python yang ditempatkan bersama Petunjuk ini menjelaskan cara memulai sidecar Python yang ditempatkan bersama di samping pekerja Pathways Anda.

Pemuatan data

Selama pelatihan, kita berulang kali memuat batch dari set data untuk dimasukkan ke dalam model. Memiliki pemuat data asinkron yang efisien yang membagi batch di seluruh host penting untuk menghindari kekurangan tugas akselerator. Saat menjalankan pelatihan dengan Pathways, pemuat data berjalan di VM CPU (tidak seperti VM TPU yang digunakan pada penyiapan multi-pengontrol) dan mengirimkan data ke VM TPU. Hal ini menimbulkan latensi yang lebih tinggi dalam membaca data, tetapi hal itu sebagian dimitigasi dengan membaca X jumlah batch di host CPU dan mengirimkan data yang dibaca secara asinkron ke TPU. Solusi ini sudah cukup saat berjalan dalam skala kecil hingga sedang.

Untuk performa yang optimal dalam skala besar, sebaiknya lakukan penempatan bersama pipeline data input Anda dengan menggunakan Python yang ditempatkan bersama untuk menjalankan pipeline data Anda langsung di akselerator. Hal ini menghilangkan hambatan CPU dan memanfaatkan interkoneksi cepat TPU untuk transfer data.

Anda dapat menemukan implementasi referensi untuk memigrasikan pipeline input berbasis TFDS di implementasi RemoteIterator dalam multihost_dataloading.py. Implementasi ini berfungsi di JAX dan Pathways multi-pengontrol secara terdistribusi menggunakan colocated Python JAX API.

Pembuatan Versi Jax

Rilis Pathways sangat terkait dengan versi JAX untuk memastikan kompatibilitas dan stabilitas. Untuk menghindari potensi masalah, pastikan artefak Pathways dan versi JAX Anda selaras. Setiap rilis Pathways secara jelas menentukan versi JAX yang kompatibel melalui tag dalam bentuk jax-<version>.

Cache Kompilasi

Cache kompilasi persisten Pathways adalah fitur yang memungkinkan server Pathways menyimpan executable XLA yang dikompilasi di lokasi persisten, seperti Cloud Storage, untuk menghindari kompilasi yang berlebihan. Fitur ini diaktifkan secara default. Lokasi cache diteruskan sebagai tanda --gcs_scratch_location ke container pekerja Pathways dan pengelola resource. Untuk meminimalkan biaya penyimpanan terkait, cache melampirkan kebijakan siklus proses ke lokasi Cloud Storage. Ada batas 50 kebijakan per bucket Cloud Storage. Oleh karena itu, sebaiknya gunakan lokasi Cloud Storage umum di semua beban kerja.

Cache ini mirip dengan cache kompilasi JAX yang dinonaktifkan oleh pathwaysutils.initialize() untuk workload Pathways.

Pembuatan profil

Anda dapat menggunakan profiler JAX untuk membuat rekaman aktivitas program JAX. Ada dua cara umum yang didukung dengan Pathways:

  • Terprogram
    • Mengambil profil secara terprogram dari kode JAX Anda
  • Manual
    • Merekam profil sesuai permintaan setelah memulai server profiler dari kode JAX Anda

Dalam kedua kasus tersebut, profil ditulis ke bucket Cloud Storage. Akan ada beberapa file rekaman aktivitas yang dibuat di bucket Cloud Storage yang mungkin berada di folder stempel waktu yang berbeda, misalnya:

  • Proses Python utama yang memanggil rekaman aktivitas (biasanya VM notebook Anda): <jax-client-vm-name>.xplane.pb
  • Proxy IFRT Pathways: client.<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb
  • Pengelola Resource Pathways: server.*<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb
  • Pekerja Pathways: server.*<tpu-node-name>.xplane.pb

File rekaman aktivitas ini dapat dianalisis dengan TensorBoard dengan menjalankan perintah berikut. Untuk mengetahui informasi selengkapnya tentang TensorBoard dan semua alat profiling-nya, lihat Mengoptimalkan performa TensorFlow menggunakan Profiler.

# verify trace files are present
gsutil ls -l -r gs://BUCKET/PREFIX

# View on tensorboard
tensorboard --logdir=gs://BUCKET/PREFIX

Ganti kode berikut:

  • BUCKET : bucket Cloud Storage untuk menyimpan file rekaman aktivitas
  • PREFIX: jalur dalam bucket Cloud Storage Anda untuk menyimpan file rekaman aktivitas

Pengambilan profil terprogram

Merekam profil dari dalam kode Anda. Profil disimpan di dalam gs://<bucket>/<prefix> di bawah direktori stempel waktu

import jax
import pathwaysutils

pathwaysutils.initialize()

jax.profiler.start_trace("gs://BUCKET/PREFIX")

# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()

jax.profiler.stop_trace()

Pengambilan profil manual

Untuk merekam informasi profil secara manual, Anda harus memulai server profiler dari kode Python Anda:

import jax
import pathwaysutils

pathwaysutils.initialize()
jax.profiler.start_server(jax_profiler_port)

# Your JAX code
jax.profiler.stop_server() # this is functinoally a no-op

Saat server profiler berjalan, Anda dapat merekam profil dan mengekspor data ke lokasi Cloud Storage target:

export DURATION_IN_SECS=6
export OUTPUT_DIR=gs://<var>BUCKET</var>/<var>PREFIX</var>

curl -d "{\"duration_ms\":\"${DURATION_IN_SECS} * 1000 }}\", \"repository_path\":\"${OUTPUT_DIR}\"}" -H "Content-Type: application/json" -X POST http://localhost:<jax_profiler_port

Anda dapat menemukan informasi pengaturan waktu untuk metode klien proxy IFRT seperti Compile dan Execute dalam rekaman aktivitas program Anda. Peristiwa ini, yang menjelaskan interaksi dengan server gRPC IFRT Proxy selama kompilasi dan eksekusi, muncul di thread bernama GrpcClientSessionUserFuturesWorkQueue. Dengan memeriksa thread ini dalam rekaman aktivitas, Anda dapat memperoleh insight tentang performa operasi ini.

Flag XLA

Saat menggunakan Pathways, Anda perlu menyetel flag XLA di container pathways-proxy. Anda dapat melakukannya menggunakan XPK atau PathwaysJob API.

Saat menggunakan XPK, tetapkan flag XLA seperti berikut:

--custom-pathways-proxy-server-args="--xla_flag_1=value1 --xla_flag_2=value2"

Saat menggunakan PathwaysJob API, tetapkan flag XLA seperti berikut:

apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USER
spec:
  customComponents:
  - componentType: proxy_server
    customFlags:
    - --xla_flag_1=value1
    - --xla_flag_2=value2

Ganti kode berikut:

  • USER : nama pengguna Google Cloud Anda
  • value[n]: tanda XLA yang ingin Anda tetapkan

Dump HLO

Untuk mempelajari lebih dalam input Pengoptimal Tingkat Tinggi (HLO) yang diberikan ke kompiler XLA, Anda dapat mengonfigurasi Pathways untuk mengekspor HLO ke lokasi Cloud Storage yang ditentukan sebagai berikut:

apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USER
spec:
  customComponents:
  - componentType: proxy_server
    customEnv:
    - name: XLA_FLAGS
      value: "--xla_dump_to=gs://your-gcs-bucket/your-desired-prefix/"

Langkah berikutnya