將 JAX 工作負載移植到 Pathways

由於 JAX with Pathways 具有分散式特性,部分作業可能會因通訊負擔而無法順利擴充。雖然 Pathways 會透過非同步調度等功能,盡量減少這些額外負擔,但將 JAX 工作負載移植到 Pathways 時,或將 JAX 與 Pathways 工作負載擴充至大量加速器時,仍有幾件事需要注意。

事前準備

請確認您已備妥以下項目:

程序索引

使用 Pathways 的 JAX 會將 Pathways 叢集中的所有裝置視為本機。 這可簡化裝置管理作業,並讓 JAX 運用所有可用資源。實務上,這表示:

  • 所有裝置的 jax.process_index() 一律為 0。
  • jax.devices()jax.local_devices() 會傳回整個作業中的所有 TPU 裝置。

硬體類型和主機代管

為獲得最佳效能,請將所有 Pathways 元件和使用者工作放在相同的 Google Cloud 雲端區域。使用大型 CPU,例如 IFRT Proxy 和資源管理工具。建議至少使用專屬 n2-standard-64,內含 64 個 vCPU 和 256 GB 記憶體。

PathwaysUtils

Pathways-utils 是以 Python 為基礎的 GitHub 存放區,提供實用工具,可簡化在 Pathways on Cloud 架構上部署及執行 JAX 工作負載的程序。這個套件會處理雲端環境的必要調整,讓 JAX 開發人員專注於核心機器學習工作流程,並盡量減少平台專屬設定。具體來說,這項服務提供:

  • 「Proxy」JAX 後端:這個自訂後端可讓 JAX 應用程式透過設定 JAX_PLATFORMS=proxy 環境變數,使用 Pathways 基礎架構。
  • 整合式剖析公用程式:剖析功能可協助您瞭解應用程式效能。使用 jax.profiler.start_tracejax.profiler.start_server 等標準 JAX 剖析 API,您不僅可以剖析 JAX 程式碼,還能剖析基礎 Pathways 元件,全面瞭解雲端環境中的執行作業。
  • 使用 Orbax 進行分散式檢查點設定:自訂 Orbax 檢查點處理常式,可讓您在 Pathways 環境中使用 Orbax 程式庫時,使用分散式檢查點並還原檢查點。這項整合功能旨在運作時,無須變更現有的 Orbax 檢查點程式碼 (只要匯入 pathwaysutils 即可)。
  • 彈性訓練基本元素:提供基礎彈性訓練基本元素,您可以使用這些元素,透過 Pathways 建構穩健且可擴充的訓練工作流程。這些基本元素可讓訓練工作動態調整可用資源的變化,進而提升雲端環境的效率和韌性。

查核點機制

Orbax 經過 Pathways 的完整測試,可搭配 Cloud Storage 進行分散式檢查點作業和還原。在 train.py 中呼叫 import pathwaysutils; pathwaysutils.initialize() 時,系統會註冊自訂 ArrayHandler,透過 IFRT 代理程式有效處理檢查點作業,讓加速器上的 Pathways 工作人員直接儲存及還原資料。

共置 Python

共置 Python 是開放原始碼 JAX API,可讓您直接在 TPU 或 GPU 主機上執行使用者指定的 Python 程式碼,在多控制器 JAX 中更為簡單。這樣一來,您就能執行更多需要大量運算資源的作業,例如資料載入和檢查點作業,避免在用戶端和 TPU 電腦之間傳輸資料。如要設定 Pathways 叢集,以執行共置的 Python JAX API,請按照共置 Python README 中的操作說明進行。這些操作說明會說明如何啟動共置的 Python Sidecar,以及 Pathways 工作人員。

資料載入

在訓練期間,我們會從資料集重複載入批次,並將其饋送至模型。擁有可將批次分散到主機的有效非同步資料載入器,對於避免加速器工作量不足非常重要。使用 Pathways 執行訓練時,資料載入器會在 CPU VM 上執行 (與多重控制器設定中使用的 TPU VM 不同),並將資料傳送至 TPU VM。這樣會導致讀取資料的延遲時間較長,但 CPU 主機上預先讀取 X 個批次,並將讀取的資料非同步傳送至 TPU,可部分緩解這個問題。如果執行規模不大,這個解決方案就已足夠。

為獲得最佳大規模效能,我們強烈建議使用共置 Python,在加速器上直接執行資料管道,藉此共置輸入資料管道。這可消除 CPU 瓶頸,並利用 TPU 的快速互連功能傳輸資料。

您可以在 multihost_dataloading.pyRemoteIterator 實作中,找到遷移以 TFDS 為基礎的輸入管道參考實作。這項實作項目可透過共置 Python JAX API,以分散式方式在多控制器 JAX 和 Pathways 上運作。

Jax 版本管理

為確保相容性和穩定性,Pathways 版本與 JAX 版本緊密結合。為避免發生潛在問題,請確認 Pathways 構件和 JAX 版本一致。每個 Pathways 版本都會透過 jax-<version> 形式的標記,清楚指定相容的 JAX 版本。

編譯快取

Pathways 持續性編譯快取功能可讓 Pathways 伺服器將編譯的 XLA 可執行檔儲存在持續性位置 (例如 Cloud Storage),避免重複編譯。這項功能預設為啟用。快取位置會以 --gcs_scratch_location 標記的形式傳遞至資源管理員和 Pathways 工作站容器。為盡量減少相關儲存空間費用,快取會將生命週期政策附加至 Cloud Storage 位置。每個 Cloud Storage 值區最多可有 50 項政策。因此,我們建議在所有工作負載中使用通用的 Cloud Storage 位置。

這個快取類似於 JAX 編譯快取,但 pathwaysutils.initialize() 會針對 Pathways 工作負載停用這個快取。

分析

您可以使用 JAX 分析器產生 JAX 程式的追蹤記錄。路徑支援兩種常見方式:

  • 程式輔助
    • 以程式輔助方式從 JAX 程式碼擷取設定檔
  • 手動
    • 從 JAX 程式碼啟動分析器伺服器後,視需要擷取設定檔

在這兩種情況下,剖析檔都會寫入 Cloud Storage bucket。Cloud Storage bucket 中可能會建立多個追蹤記錄檔,例如:

  • 叫用追蹤記錄的主要 Python 程序 (通常是筆記本 VM): <jax-client-vm-name>.xplane.pb
  • Pathways IFRT 代理:client.<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb
  • 路徑資源管理員:server.*<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb
  • Pathways 工作者:server.*<tpu-node-name>.xplane.pb

執行下列指令,即可使用 TensorBoard 分析這些追蹤記錄檔。如要進一步瞭解 TensorBoard 和所有分析工具,請參閱「使用分析器將 TensorFlow 效能最佳化」。

# verify trace files are present
gsutil ls -l -r gs://BUCKET/PREFIX

# View on tensorboard
tensorboard --logdir=gs://BUCKET/PREFIX

更改下列內容:

  • BUCKET:用於儲存追蹤記錄檔案的 Cloud Storage bucket
  • PREFIX:Cloud Storage bucket 內的路徑,用於儲存追蹤記錄檔案

以程式輔助方式擷取設定檔

從程式碼內部擷取設定檔。設定檔會儲存在 gs://<bucket>/<prefix> 的時間戳記目錄中

import jax
import pathwaysutils

pathwaysutils.initialize()

jax.profiler.start_trace("gs://BUCKET/PREFIX")

# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()

jax.profiler.stop_trace()

手動擷取設定檔

如要手動擷取剖析資訊,您必須從 Python 程式碼啟動剖析器伺服器:

import jax
import pathwaysutils

pathwaysutils.initialize()
jax.profiler.start_server(jax_profiler_port)

# Your JAX code
jax.profiler.stop_server() # this is functinoally a no-op

剖析器伺服器執行期間,您可以擷取剖析結果,並將資料匯出至目標 Cloud Storage 位置:

export DURATION_IN_SECS=6
export OUTPUT_DIR=gs://<var>BUCKET</var>/<var>PREFIX</var>

curl -d "{\"duration_ms\":\"${DURATION_IN_SECS} * 1000 }}\", \"repository_path\":\"${OUTPUT_DIR}\"}" -H "Content-Type: application/json" -X POST http://localhost:<jax_profiler_port

您可以在程式的追蹤記錄中,找到 IFRT 代理用戶端方法 (例如 CompileExecute) 的時間資訊。這些事件會詳細說明編譯和執行期間與 IFRT Proxy gRPC 伺服器的互動,並顯示在名為 GrpcClientSessionUserFuturesWorkQueue 的執行緒上。檢查追蹤記錄中的這個執行緒,即可深入瞭解這些作業的效能。

XLA 旗標

使用 Pathways 時,您需要在 pathways-proxy 容器中設定 XLA 旗標。您可以使用 XPK 或 PathwaysJob API 執行這項操作。

使用 XPK 時,請設定 XLA 標記,如下所示:

--custom-pathways-proxy-server-args="--xla_flag_1=value1 --xla_flag_2=value2"

使用 PathwaysJob API 時,請設定 XLA 旗標,如下所示:

apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USER
spec:
  customComponents:
  - componentType: proxy_server
    customFlags:
    - --xla_flag_1=value1
    - --xla_flag_2=value2

更改下列內容:

  • USER:您的 Google Cloud 使用者名稱
  • value[n]:要設定的 XLA 旗標

HLO 傾印

如要深入瞭解提供給 XLA 編譯器的高階最佳化工具 (HLO) 輸入內容,您可以設定 Pathways,將 HLO 傾印至指定的 Cloud Storage 位置,方法如下:

apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USER
spec:
  customComponents:
  - componentType: proxy_server
    customEnv:
    - name: XLA_FLAGS
      value: "--xla_dump_to=gs://your-gcs-bucket/your-desired-prefix/"

後續步驟