Formazione resiliente con Pathways

Pathways offre vantaggi in termini di resilienza nei seguenti modi:

  • Sospensione-Ripresa: tolleranza in caso di interruzioni pianificate, come gli avvisi di prerilascio, senza che l'utente debba scrivere codice personalizzato per la gestione del prerilascio.
  • Addestramento elastico: tolleranza in caso di guasti hardware non pianificati senza causare l'arresto anomalo del client, ma che richiedono agli utenti di scrivere codice di ripristino specifico del modello.

Prima di iniziare

Assicurati di avere:

Sospensione-Ripresa

In genere, GKE invia un avviso di prerilascio a un pod dell'acceleratore prima che il pod venga prerilasciato. La tolleranza al prerilascio di Pathways è abilitata per impostazione predefinita in tutti i deployment cloud e i job dell'acceleratore Pathways sono in ascolto di questi avvisi.

Quando arriva un avviso di prerilascio, Pathways determina innanzitutto se il carico di lavoro corrente è ripristinabile, ovvero se Pathways può salvare e ripristinare il carico di lavoro in modo trasparente. In caso affermativo, tenta di sospendere in modo trasparente il carico di lavoro ML scrivendo il suo stato attuale in uno spazio di archiviazione permanente come Cloud Storage prima che GKE rimuova i job dell'acceleratore. Quando GKE riprogramma i job in un secondo momento, Pathways riprende il carico di lavoro ML leggendo il suo stato persistente.

Se il carico di lavoro non è ripristinabile, Pathways arresta il job dell'acceleratore e inoltra l'errore al job se è configurato l'addestramento elastico. Se l'addestramento elastico non è configurato, GKE riavvia l'intero carico di lavoro in base alle norme di riavvio di JobSet.

I carichi di lavoro ML tipici definiti utilizzando JAX si basano su componenti XLA di Pathways senza stato che possono essere ripristinati utilizzando uno snapshot della memoria a larghezza di banda elevata (HBM). Alcuni carichi di lavoro ML , come quelli definiti utilizzando l'API Python colocalizzata JAX , si basano su componenti Pathways con stato; questi non sono ripristinabili.

Addestramento elastico

L'addestramento elastico consente al job di addestramento di continuare anche in caso di guasti hardware. Ciò si ottiene grazie a una combinazione di funzionalità del sistema Pathways e logica di ripristino del modello definita dall'utente:

  • Rilevamento di errori: quando si verifica un guasto hardware (ad esempio, un worker TPU si arresta in modo anomalo), il sistema Pathways lo rileva e invia una notifica al job di addestramento dell'utente tramite un'eccezione la volta successiva che si accede ai dati che si trovavano sull'hardware. Questa notifica non causa l'arresto anomalo del carico di lavoro; consente al codice di gestire la notifica e riconfigurare le risorse per continuare l'elaborazione o uscire normalmente.
  • Gestore di elasticità definito dall'utente: il codice del modello dell'utente deve essere in grado di gestire questa eccezione. Questo è ciò che lo rende un "recupero specifico del modello".
    • Snapshot: l'approccio più comune consiste nel salvare periodicamente gli snapshot dello stato del modello. Quando si verifica un errore, puoi caricare l'ultimo snapshot per riprendere l'addestramento.
    • Riconfigurazione: probabilmente dovrai riconfigurare il job di addestramento per adattarlo al numero di slice disponibili. Ad esempio, se una slice smette di funzionare, puoi ridurre di uno il numero di slice attive finché non è disponibile una sostituzione. Per ulteriori informazioni, consulta Gestore elastico.
    • Aggiornamenti del grafico di dati/calcoli: il codice deve gestire eventuali modifiche al numero di dispositivi disponibili per il calcolo ricreando il grafico di calcolo in base alle esigenze. Ciò potrebbe comportare la ripartizione dei dati o la ricompilazione del modello.
  • Ruolo di Pathways nel ripristino: Pathways fornisce le primitive per supportare la riconfigurazione definita dall'utente:
    • Sostituzione delle slice: se una slice non funzionante viene sostituita, il client può essere informato quando la nuova slice è disponibile. Il codice può quindi essere riconfigurato per utilizzare questa nuova slice.
    • Ripristino trasparente: Pathways gestisce i dettagli di livello inferiore del ripristino, come il ristabilimento delle connessioni alle parti integre del cluster.
  • Utilità in pathwaysutils: un insieme di utilità Pathways definite in pathways-utils.

Implementare un gestore elastico

La maggior parte del codice che dovrai scrivere sarà in un gestore elastico definito dall'utente. Questo gestore reagisce agli eventi elastici (ad esempio, quando una slice TPU diventa non disponibile) ricreando la mesh e reinizializzando il ciclo di addestramento.

Ogni carico di lavoro è unico. La complessità del gestore elastico può aumentare con la complessità del carico di lavoro. Gli input e gli output del gestore devono essere gli argomenti e i valori restituiti minimi necessari per reinizializzare il ciclo di addestramento.

def elastic_handler(elastic_utils, *args, **kwargs):
  mesh = initialize_mesh(**kwargs["mesh_kwargs"])
  initial_state, initial_step, jitted_train_step, other_variables =
      initialize_training_loop(mesh, **kwargs["initialize_training_loop_kwargs"])

  step, snapshot = elastic_utils.get_next_snapshot()
  state = initial_state.replace(**snapshot)

  return state, step, mesh, jitted_train_step, other_variables

Aggiornare il ciclo di addestramento

Devi apportare le seguenti modifiche al ciclo di addestramento:

  1. Crea un gestore elastico
  2. Racchiudi il ciclo di addestramento all'interno di blocchi try-except che gestiscono jax.errors.JaxRuntimeErrors
  3. All'interno del gestore jax.errors.JaxRuntimeError, chiama maybe_reshard_down. Il gestore elastico eseguirà il resharding se l'errore è correlato a un evento elastico o lo riporterà.
  4. Chiama maybe_snapshot e maybe_reshard_up alla fine del ciclo di addestramento
import pathwaysutils
from pathwaysutils.elastic import manager

pathwaysutils.initialize()

def initialize_mesh(**kwargs):
  ...


def initialize_training_loop(**kwargs):
  ...


def train_loop(
    final_step,
    elastic_manager,
    mesh_kwargs,
    initialize_training_loop_kwargs,
):
  mesh = initialize_mesh(**mesh_kwargs)
  initial_state, initial_step, jitted_train_step, other_variables =
      initialize_training_loop(mesh, **initialize_training_loop_kwargs)

  step = initial_step
  while step < final_step:
    try:
      state = jitted_train_step(state)

      elastic_manager.maybe_snapshot(step=step, snapshot=state)
      handler_returns = elastic_manager.maybe_reshard_up(
          step=step,
          snapshot=state,
          elastic_handler=elastic_handler,
          handler_args=(),
          handler_kwargs=dict(
              mesh_kwargs=mesh_kwargs,
              initialize_training_loop_kwargs=initialize_training_loop_kwargs,
          ),
      )
      if handler_returns:
        state, step, mesh, jitted_train_step, other_variables = handler_returns
      step += 1
    except jax.errors.JaxRuntimeError as error:
      handler_returns = elastic_manager.maybe_reshard_down(
          error=error,
          elastic_handler=elastic_handler,
          handler_args=(),
          handler_kwargs=dict(
              mesh_kwargs=mesh_kwargs,
              initialize_training_loop_kwargs=initialize_training_loop_kwargs,
          ),
      )
      if handler_returns:
        state, step, mesh, jitted_train_step, other_variables = handler_returns

  return state


def main():
  elastic_manager = manager.Manager(
      devices=jax.devices(),
      snapshot_period=10,
      snapshot_buffer_size=1,
      reshard_check_period=5,
      max_elastic_down_event_count=10,
      max_reshard_retry_count=3,
  )

  train_loop(100, elastic_manager, {}, {})

Configurare il gestore elastico

Il gestore elastico può essere configurato in diversi modi. La frequenza degli snapshot è determinata dal periodo di snapshot. Il periodo di snapshot influisce sul numero medio di passaggi persi a causa di un evento elastico. Il periodo di controllo del resharding determina la frequenza con cui il ciclo di addestramento eseguirà il polling per la disponibilità delle slice. max_elastic_down_event_count consente di impostare il numero di eventi elastici dovuti alla perdita di slice supportati dal ciclo di addestramento. max_reshard_retry_count specifica il numero di volte in cui il gestore elastico deve riprovare a eseguire il resharding. Il gestore è un oggetto singleton e deve essere creato una sola volta.

Snapshot

In base alla configurazione del gestore elastico, la funzione può creare snapshot dei dati nella memoria host che saranno disponibili per l'utilizzo da parte del gestore elastico durante un evento elastico.

Ridurre il partizionamento

Dopo aver rilevato un jax.errors.JaxRuntimeError, Pathways verificherà se l'errore è dovuto a un evento elastico a causa di una slice persa. In caso affermativo, chiamerà il gestore elastico in un ciclo fino al successo o al raggiungimento del numero massimo di tentativi. Se l'errore non è dovuto a un evento elastico, verrà generato di nuovo. I valori restituiti del gestore elastico vengono passati al chiamante.

Aumentare il partizionamento

In base alla configurazione del gestore elastico e se sono presenti slice non disponibili, Pathways verificherà se sono diventate disponibili altre slice. In caso affermativo, salverà immediatamente uno snapshot (se non è già stato creato uno snapshot preesistente per il passaggio corrente) e chiamerà il gestore elastico in un ciclo fino al successo o al raggiungimento del numero massimo di tentativi. Se si verifica il resharding, i valori restituiti del gestore elastico vengono passati al chiamante. In caso contrario, viene restituito None.

Hot-swap

Hot-swap si riferisce a una funzionalità dell'API JobSet di GKE in cui un job con priorità più elevata può acquisire rapidamente le risorse da un job con priorità inferiore, riducendo al minimo i tempi di inattività e garantendo un ripristino più rapido.

Quando viene creato un JobSet, GKE pianifica il carico di lavoro su più slice, come specificato nella configurazione di JobSet. Se si verifica un guasto hardware su una o più slice, i pod interessati vengono contrassegnati come non riusciti. Quando riprogrammi questo JobSet, se hai scelto di mantenere una slice di riserva nel cluster GKE che potrebbe essere utilizzata per un job con priorità inferiore, il sistema JobSet rimapperà il carico di lavoro della slice non riuscita del job con priorità più elevata sulla slice di riserva utilizzata dal job con priorità inferiore all'interno dello stesso cluster GKE. In genere, questa rimappatura richiede meno di un minuto.

Al riavvio di JobSet, l'hot-swap può verificarsi nelle seguenti situazioni:

  1. Modalità predefinita: se sono disponibili slice TPU di riserva e inattive all'interno dello stesso cluster, lo scheduler Kubernetes darà la priorità alla pianificazione dei job riavviati su queste slice anziché attendere la riparazione delle slice non riuscite. Ciò consente un ripristino più rapido.
  2. Carichi di lavoro eterogenei: nei cluster che eseguono più carichi di lavoro con una PriorityClass Kubernetes configurata, un JobSet riavviato può attivare un hot swap. Se l'affinità del job riavviato corrisponde alle risorse di un job con priorità inferiore, Kubernetes esegue il prerilascio del job con priorità inferiore, consentendo l'avvio immediato del job con priorità più elevata. Ad esempio, puoi configurare i pod worker di Pathways con priorità diverse utilizzando PriorityClass.

Per utilizzare le priorità nel cluster, definisci una classe di priorità, ad esempio:

kind: PriorityClass
metadata:
  name: high-prior-job
value: 2000
globalDefault: false
description: "This priority class should be used for high priority job."

Applica questo file YAML al cluster GKE:

kubectl apply -f high-prior-job.yaml

Quindi, collega la nuova classe di priorità al job worker di Pathways aggiungendo il seguente testo al podspec del pod pathways-worker.

priorityClassName: high-prior-job

Passaggi successivi