透過 Pathways 進行彈性訓練

路徑可透過下列方式提供復原能力優勢:

  • 暫停/繼續:可容忍預先通知等計畫性中斷,不需使用者編寫任何自訂搶占處理程式碼。
  • 彈性訓練:容許發生非預期的硬體故障,不會導致用戶端當機,但使用者必須編寫模型專屬的復原程式碼。

    事前準備

    請確認您已備妥以下項目:

    暫停/繼續

    一般來說,GKE 會先將搶占通知傳送至加速器 Pod,再搶占該 Pod。所有雲端部署作業都會預設啟用 Pathways 搶占容許度,而 Pathways 加速器工作會監聽這些通知。

    收到搶占通知時,Pathways 會先判斷目前的工作負載是否可還原,也就是 Pathways 是否能以透明方式儲存及還原工作負載。如果是,系統會嘗試將機器學習工作負載的目前狀態寫入 Cloud Storage 等永久儲存空間,然後在 GKE 驅逐加速器工作前,以透明方式暫停工作負載。稍後 GKE 重新排定作業時,Pathways 會讀取保存的狀態,繼續執行 ML 工作負載。

    如果工作負載無法還原,路徑會關閉加速器工作,並在設定 Elastic training 時,將失敗轉送至工作。如果未設定彈性訓練,GKE 會根據 JobSet 重啟政策重新啟動整個工作負載。

    使用 JAX 定義的典型機器學習工作負載會依賴無狀態的 Pathways XLA 元件,這些元件可使用高頻寬記憶體 (HBM) 快照還原。某些 ML 工作負載 (例如使用 JAX 共置 Python API 定義的工作負載) 依賴有狀態的 Pathways 元件,這些元件無法還原。

    彈力帶訓練

    彈性訓練可讓訓練工作在發生硬體故障時繼續執行。這項功能結合了 Pathways 系統功能和使用者定義的模型復原邏輯:

    • 偵測失敗:發生硬體故障時 (例如 TPU 工作人員當機),Pathways 系統會偵測到這項問題,並在下次存取位於該硬體上的資料時,透過例外狀況通知使用者的訓練作業。這項通知不會導致工作負載當機,而是讓程式碼處理通知,並重新設定資源,以便繼續處理或正常結束。
    • 使用者定義的彈性處理常式:使用者模型程式碼必須能夠處理這項例外狀況。這就是「特定機型復原」的意義。
      • 快照:最常見的做法是定期儲存模型狀態的快照。發生失敗時,您可以從最近的快照載入,繼續訓練。
      • 重新設定:您可能需要重新設定訓練工作,以調整可用的切片數量。舉例來說,如果其中一個切片停止運作,您可能會將有效切片數量減少一個,直到有替代切片為止。詳情請參閱「彈性處理常式」。
      • 資料/運算圖更新:如果運算可用的裝置數量有任何變更,您的程式碼需要視需要重新建立運算圖,以處理這些變更。這可能需要重新分割資料或重新編譯模型。
    • 路徑在復原程序中的角色:路徑提供基本元素,支援使用者定義的重新設定:
      • 切片替換:如果替換失敗的切片,用戶端可以在新切片可用時收到通知。您的程式碼隨後可以重新設定,改用這個新切片。
      • 透明復原:路徑會處理復原的低階詳細資料,例如重新建立與叢集健康部分的連線。
    • pathwaysutils 中的公用程式:在 pathways-utils 中定義的一組 Pathways 公用程式。

    實作彈性處理常式

    您必須撰寫的大部分程式碼,都會位於使用者定義的彈性處理常式中。這個處理常式會對彈性事件 (例如 TPU 節點無法使用) 做出反應,方法是重新建立網格並重新初始化訓練迴圈。

    每個工作負載都不盡相同,彈性處理常式的複雜度可能會隨著工作負載的複雜度而擴大。處理常式的輸入和輸出內容應為重新初始化訓練迴圈所需的最低引數和傳回值。

    def elastic_handler(elastic_utils, *args, **kwargs):
      mesh = initialize_mesh(**kwargs["mesh_kwargs"])
      initial_state, initial_step, jitted_train_step, other_variables =
          initialize_training_loop(mesh, **kwargs["initialize_training_loop_kwargs"])
    
      step, snapshot = elastic_utils.get_next_snapshot()
      state = initial_state.replace(**snapshot)
    
      return state, step, mesh, jitted_train_step, other_variables
    

    更新訓練迴圈

    您需要對訓練迴圈進行下列變更:

    1. 建立彈性管理員
    2. 將訓練迴圈包裝在 try-except 區塊中,處理 jax.errors.JaxRuntimeError
    3. jax.errors.JaxRuntimeError 處理常式中,呼叫 maybe_reshard_down。 如果錯誤與彈性事件相關,彈性管理員會重新分片,否則會重新引發錯誤。
    4. 在訓練迴圈結束時呼叫 maybe_snapshotmaybe_reshard_up
    import pathwaysutils
    from pathwaysutils.elastic import manager
    
    pathwaysutils.initialize()
    
    def initialize_mesh(**kwargs):
      ...
    
    
    def initialize_training_loop(**kwargs):
      ...
    
    
    def train_loop(
        final_step,
        elastic_manager,
        mesh_kwargs,
        initialize_training_loop_kwargs,
    ):
      mesh = initialize_mesh(**mesh_kwargs)
      initial_state, initial_step, jitted_train_step, other_variables =
          initialize_training_loop(mesh, **initialize_training_loop_kwargs)
    
      step = initial_step
      while step < final_step:
        try:
          state = jitted_train_step(state)
    
          elastic_manager.maybe_snapshot(step=step, snapshot=state)
          handler_returns = elastic_manager.maybe_reshard_up(
              step=step,
              snapshot=state,
              elastic_handler=elastic_handler,
              handler_args=(),
              handler_kwargs=dict(
                  mesh_kwargs=mesh_kwargs,
                  initialize_training_loop_kwargs=initialize_training_loop_kwargs,
              ),
          )
          if handler_returns:
            state, step, mesh, jitted_train_step, other_variables = handler_returns
          step += 1
        except jax.errors.JaxRuntimeError as error:
          handler_returns = elastic_manager.maybe_reshard_down(
              error=error,
              elastic_handler=elastic_handler,
              handler_args=(),
              handler_kwargs=dict(
                  mesh_kwargs=mesh_kwargs,
                  initialize_training_loop_kwargs=initialize_training_loop_kwargs,
              ),
          )
          if handler_returns:
            state, step, mesh, jitted_train_step, other_variables = handler_returns
    
      return state
    
    
    def main():
      elastic_manager = manager.Manager(
          devices=jax.devices(),
          snapshot_period=10,
          snapshot_buffer_size=1,
          reshard_check_period=5,
          max_elastic_down_event_count=10,
          max_reshard_retry_count=3,
      )
    
      train_loop(100, elastic_manager, {}, {})
    

    設定彈性管理工具

    彈性管理員可透過幾種方式設定。快照頻率取決於快照週期。快照期會影響因彈性事件而損失的平均步數。重新分片檢查週期會決定訓練迴圈輪詢切片可用性的頻率。max_elastic_down_event_count 可讓您設定訓練迴圈支援的切片遺失彈性事件數量。max_reshard_retry_count 指定彈性管理員應重試重新分片的次數。管理員是單例物件,因此只能建立一次。

    快照

    根據彈性管理工具設定,函式可能會將資料快照到主機記憶體,供彈性事件期間的彈性處理常式使用。

    減少分片

    捕捉到 jax.errors.JaxRuntimeError 後,路徑會檢查錯誤是否是由於切片遺失而導致彈性事件所致。如果是,系統會以迴圈呼叫彈性處理常式,直到成功或達到重試次數上限為止。如果錯誤並非因彈性事件而起,系統會再次引發錯誤。彈性處理常式的傳回值會傳遞給呼叫端。

    增加分片

    根據彈性管理工具設定,如果沒有可用的切片,路徑會檢查是否有其他切片可用。如果是,系統會立即儲存快照 (如果目前步驟的快照尚未建立),並在迴圈中呼叫彈性處理常式,直到成功或達到重試次數上限為止。如果發生重新分片,彈性處理常式的傳回值會傳遞至呼叫端。否則會傳回 None

    熱插拔

    熱插拔是指 GKE JobSet API 的一項功能,可讓優先順序較高的工作快速接管優先順序較低的工作的資源,盡量減少停機時間,並確保更快復原。

    建立 JobSet 時,GKE 會根據 JobSet 設定,在多個切片中排定工作負載。如果一或多個切片發生硬體故障,受影響的 Pod 會標示為失敗。重新排定這個 Jobset 時,如果您選擇在 GKE 叢集中保留可供優先順序較低的工作使用的備用 Slice,JobSet 系統會將優先順序較高工作失敗 Slice 的工作負載,重新對應至同一 GKE 叢集中優先順序較低工作使用的備用 Slice。重新對應通常不到一分鐘即可完成。

    JobSet 重新啟動後,在下列情況下可能會發生熱插拔:

    1. 預設模式:如果同一叢集內有閒置的備用 TPU 節點,Kubernetes 排程器會優先將重新啟動的作業排程至這些節點,而不是等待修復失敗的節點。這樣可以加快復原速度。
    2. 異質工作負載:在執行多個工作負載的叢集中,如果設定了 Kubernetes PriorityClass,重新啟動的 JobSet 可能會觸發熱交換。如果重新啟動的工作與優先順序較低的工作資源相符,Kubernetes 會先佔用優先順序較低的工作,讓優先順序較高的工作立即啟動。舉例來說,您可以使用 PriorityClass,為 Pathways 工作站 Pod 設定不同的優先順序。

    如要在叢集中使用優先順序,請定義優先順序類別,例如:

    kind: PriorityClass
    metadata:
      name: high-prior-job
    value: 2000
    globalDefault: false
    description: "This priority class should be used for high priority job."
    

    將這個 YAML 套用至 GKE 叢集:

    kubectl apply -f high-prior-job.yaml
    

    接著,將下列文字新增至 pathways-worker Pod 的 podspec,將新的優先順序類別附加至 Pathways 工作人員工作。

    priorityClassName: high-prior-job
    

    後續步驟