Mengirimkan workload JAX ke Pathways

Karena sifat JAX yang terdistribusi dengan Pathways, beberapa operasi mungkin tidak dapat diskalakan dengan baik karena overhead komunikasi. Meskipun Pathways meminimalkan overhead ini dengan fitur seperti pengiriman asinkron, ada beberapa hal yang perlu Anda ketahui saat mem-porting workload JAX ke Pathways atau menskalakan workload JAX dengan Pathways ke sejumlah besar akselerator.

Sebelum memulai

Pastikan Anda memiliki:

Indeks proses

JAX dengan Pathways memperlakukan semua perangkat di seluruh cluster Pathways Anda sebagai lokal. Hal ini menyederhanakan pengelolaan perangkat dan memungkinkan JAX menggunakan semua resource yang tersedia. Dalam praktiknya, hal ini berarti:

  • jax.process_index() selalu 0 untuk semua perangkat.
  • jax.devices() dan jax.local_devices() menampilkan semua perangkat TPU di seluruh tugas.

Jenis hardware dan kolokasi

Untuk mendapatkan performa terbaik, tempatkan semua komponen Pathways dan tugas pengguna di zona cloud yang Google Cloud sama. Gunakan CPU besar seperti proxy IFRT dan pengelola resource. Sebaiknya gunakan setidaknya n2-standard-64 khusus yang dilengkapi dengan 64 vCPU dan memori 256 GB.

PathwaysUtils

Pathways-utils adalah repositori GitHub berbasis Python yang menyediakan utilitas dan alat penting yang memungkinkan Anda menyederhanakan deployment dan eksekusi workload JAX pada arsitektur Pathways di Cloud. Paket ini menangani adaptasi yang diperlukan untuk lingkungan cloud, sehingga developer JAX dapat berfokus pada alur kerja machine learning inti mereka dengan konfigurasi khusus platform yang minimal. Secara khusus, paket ini menawarkan:

  • Backend JAX "proxy": backend kustom ini memungkinkan aplikasi JAX Anda menggunakan infrastruktur Pathways dengan menetapkan variabel lingkungan JAX_PLATFORMS=proxy.
  • Utilitas Pembuatan Profil Terintegrasi: kemampuan pembuatan profil yang memungkinkan Anda memahami performa aplikasi. Dengan menggunakan JAX profiling API standar seperti jax.profiler.start_trace dan jax.profiler.start_server, Anda dapat membuat profil tidak hanya kode JAX, tetapi juga komponen Pathways yang mendasarinya, sehingga memberikan tampilan eksekusi yang holistik dalam lingkungan cloud.
  • Checkpointing Terdistribusi dengan Orbax: pengendali checkpoint Orbax kustom yang memungkinkan Anda menggunakan checkpoint terdistribusi dan memulihkan checkpoint saat menggunakan library Orbax dalam lingkungan Pathways. Integrasi ini bertujuan untuk berfungsi tanpa memerlukan perubahan apa pun pada kode checkpointing Orbax yang ada selama kode tersebut mengimpor pathwaysutils.
  • Primitif Pelatihan Elastis: menyediakan primitif pelatihan elastis dasar yang dapat Anda gunakan untuk membuat alur kerja pelatihan yang kuat dan dapat diskalakan menggunakan Pathways. Primitif ini memungkinkan tugas pelatihan Anda beradaptasi secara dinamis terhadap perubahan resource yang tersedia, sehingga meningkatkan efisiensi dan ketahanan di lingkungan cloud.

Checkpointing

Orbax diuji secara menyeluruh dengan Pathways untuk checkpointing dan pemulihan terdistribusi dengan Cloud Storage. Saat Anda menetapkan variabel lingkungan ENABLE_PATHWAYS_PERSISTENCE=1 dan memanggil import pathwaysutils; pathwaysutils.initialize() di train.py, ArrayHandler kustom akan didaftarkan yang menangani operasi checkpoint secara efisien melalui proxy IFRT, sehingga pekerja Pathways di akselerator dapat langsung menyimpan dan memulihkan data.

Python yang ditempatkan bersama

Python yang ditempatkan bersama adalah JAX API open source yang memungkinkan Anda menjalankan kode Python yang ditentukan pengguna langsung di host TPU atau GPU, yang lebih mudah di multi-pengontrol JAX. Hal ini memungkinkan tugas yang lebih intensif komputasi, seperti pemuatan data dan checkpointing, untuk menghindari transfer data antara klien dan mesin TPU. Untuk mengonfigurasi cluster Pathways agar menjalankan Python JAX API yang ditempatkan bersama, ikuti petunjuk di README Python yang ditempatkan bersama. Petunjuk ini menjelaskan cara memulai sidecar Python yang ditempatkan bersama di samping pekerja Pathways Anda.

Pemuatan data

Selama pelatihan, kita berulang kali memuat batch dari set data untuk dimasukkan ke dalam model. Memiliki pemuat data asinkron yang efisien yang membagi batch di seluruh host penting untuk menghindari akselerator yang kekurangan pekerjaan. Saat menjalankan pelatihan dengan Pathways, pemuat data berjalan di VM CPU (tidak seperti VM TPU yang digunakan pada penyiapan multi-pengontrol) dan mengirimkan data ke VM TPU. Hal ini menyebabkan latensi yang lebih tinggi dalam membaca data, tetapi hal tersebut dikurangi sebagian dengan membaca X batch di host CPU dan mengirimkan data yang dibaca secara asinkron ke TPU. Solusi ini cukup saat berjalan dalam skala kecil hingga menengah.

Untuk performa optimal dalam skala besar, sebaiknya tempatkan bersama pipeline data input Anda dengan menggunakan Python yang ditempatkan bersama untuk menjalankan pipeline data Anda langsung di akselerator. Hal ini akan menghilangkan bottleneck CPU dan memanfaatkan interkoneksi cepat TPU untuk transfer data.

Anda dapat menemukan implementasi referensi untuk memigrasikan pipeline input berbasis TFDS dalam implementasi RemoteIterator di multihost_dataloading.py. Implementasi ini berfungsi di JAX multi-pengontrol dan Pathways secara terdistribusi menggunakan Python JAX API yang ditempatkan bersama.

Pembuatan Versi Jax

Rilis Pathways sangat terkait dengan versi JAX untuk memastikan kompatibilitas dan stabilitas. Untuk menghindari potensi masalah, pastikan artefak Pathways dan versi JAX Anda selaras. Setiap rilis Pathways secara jelas menentukan versi JAX yang kompatibel melalui tag dalam bentuk jax-<version>.

Cache Kompilasi

Cache kompilasi persisten Pathways adalah fitur yang memungkinkan server Pathways menyimpan file yang dapat dieksekusi XLA yang dikompilasi di lokasi persisten, seperti Cloud Storage, untuk menghindari kompilasi yang berlebihan. Fitur ini diaktifkan secara default. Lokasi cache diteruskan sebagai flag --gcs_scratch_location ke pengelola resource dan container pekerja Pathways. Untuk meminimalkan biaya penyimpanan terkait, cache melampirkan kebijakan siklus proses ke lokasi Cloud Storage. Ada batas 50 kebijakan per bucket Cloud Storage. Oleh karena itu, sebaiknya gunakan lokasi Cloud Storage umum di semua workload.

Cache ini mirip dengan cache kompilasi JAX yang dinonaktifkan oleh pathwaysutils.initialize() untuk workload Pathways.

Izin Cloud Storage berikut diperlukan untuk cache kompilasi:

  • storage.buckets.get: Untuk mengambil metadata bucket.
  • storage.buckets.update: Penting agar Pathways dapat menyiapkan kebijakan siklus proses objek untuk menerapkan TTL bagi penghapusan cache.
  • storage.objects.list: Untuk mencantumkan objek cache yang ada dalam bucket.
  • storage.objects.create: Untuk menulis file yang dapat dieksekusi yang dikompilasi baru ke cache.
  • storage.objects.get: Untuk membaca file yang dapat dieksekusi yang di-cache dari bucket.

Pembuatan profil

Anda dapat menggunakan profiler JAX untuk membuat pelacakan program JAX. Ada dua cara umum yang didukung dengan Pathways:

  • Terprogram
    • Mengambil profil secara terprogram dari kode JAX Anda
  • Manual
    • Mengambil profil sesuai permintaan setelah memulai server profiler dari kode JAX Anda

Dalam kedua kasus tersebut, profil ditulis ke bucket Cloud Storage. Akan ada beberapa file pelacakan yang dibuat di bucket Cloud Storage, yang berpotensi berada di folder stempel waktu yang berbeda, misalnya:

  • Proses Python utama yang memanggil pelacakan (biasanya VM notebook Anda): <jax-client-vm-name>.xplane.pb
  • Proxy IFRT Pathways: client.<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb
  • Pengelola resource Pathways: server.*<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb
  • Pekerja Pathways: server.*<tpu-node-name>.xplane.pb

File pelacakan ini dapat dianalisis dengan TensorBoard dengan menjalankan perintah berikut. Untuk mengetahui informasi selengkapnya tentang TensorBoard dan semua alat pembuatan profilnya, lihat Mengoptimalkan performa TensorFlow menggunakan Profiler.

# verify trace files are present
gsutil ls -l -r gs://BUCKET/PREFIX

# View on tensorboard
tensorboard --logdir=gs://BUCKET/PREFIX

Ganti kode berikut:

  • BUCKET : bucket Cloud Storage untuk menyimpan file pelacakan
  • PREFIX: jalur dalam bucket Cloud Storage Anda untuk menyimpan file pelacakan

Pengambilan profil terprogram

Ambil profil dari dalam kode Anda. Profil disimpan di dalam gs://<bucket>/<prefix> di bawah direktori stempel waktu

import jax
import pathwaysutils

pathwaysutils.initialize()

jax.profiler.start_trace("gs://BUCKET/PREFIX")

# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()

jax.profiler.stop_trace()

Pengambilan profil manual

Untuk mengambil informasi profil secara manual, Anda harus memulai server profiler dari kode Python Anda:

import jax
import pathwaysutils

pathwaysutils.initialize()
jax.profiler.start_server(jax_profiler_port)

# Your JAX code
jax.profiler.stop_server() # this is functionally a no-op

Saat server profiler berjalan, Anda dapat mengambil profil dan mengekspor data ke lokasi Cloud Storage target:

export DURATION_IN_SECS=6
export OUTPUT_DIR=gs://<var>BUCKET</var>/<var>PREFIX</var>

curl -d "{\"duration_ms\":\"${DURATION_IN_SECS} * 1000 }}\", \"repository_path\":\"${OUTPUT_DIR}\"}" -H "Content-Type: application/json" -X POST http://localhost:<jax_profiler_port>

Anda dapat menemukan informasi waktu untuk metode klien proxy IFRT seperti Compile dan Execute dalam pelacakan program Anda. Peristiwa ini, yang menjelaskan interaksi dengan server gRPC Proxy IFRT selama kompilasi dan eksekusi, muncul di thread bernama GrpcClientSessionUserFuturesWorkQueue. Dengan memeriksa thread ini dalam pelacakan, Anda dapat memperoleh insight tentang performa operasi ini.

Flag XLA

Saat menggunakan Pathways, Anda harus menetapkan flag XLA di container pathways-proxy. Anda dapat melakukannya menggunakan XPK atau PathwaysJob API.

Saat menggunakan XPK, tetapkan flag XLA seperti berikut:

--custom-pathways-proxy-server-args="--xla_flag_1=value1 --xla_flag_2=value2"

Saat menggunakan PathwaysJob API, tetapkan flag XLA seperti berikut:

apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USER
spec:
  customComponents:
  - componentType: proxy_server
    customFlags:
    - --xla_flag_1=value1
    - --xla_flag_2=value2

Ganti kode berikut:

  • USER : nama pengguna Anda Google Cloud
  • value[n]: flag XLA yang ingin Anda tetapkan

Dump HLO

Untuk mempelajari lebih lanjut input High Level Optimizer (HLO) yang diberikan ke compiler XLA, Anda dapat mengonfigurasi Pathways untuk membuang HLO ke lokasi Cloud Storage yang ditentukan sebagai berikut:

apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USER
spec:
  customComponents:
  - componentType: proxy_server
    customEnv:
    - name: XLA_FLAGS
      value: "--xla_dump_to=gs://your-bucket/your-desired-prefix/"

Langkah berikutnya