Formazione resiliente con Pathways

Pathways offre vantaggi di resilienza nei seguenti modi:

  • Sospensione-Ripresa: tolleranza di fronte a interruzioni pianificate come avvisi di preemptive senza che l'utente debba scrivere codice personalizzato per la gestione del preemptive.
  • Elastic Training: tolleranza di fronte a guasti hardware non pianificati senza causare l'arresto anomalo del client, ma richiedendo agli utenti di scrivere codice di ripristino specifico per il modello.

    Prima di iniziare

    Assicurati di avere:

    Sospendi-riprendi

    In genere, GKE invia una notifica di prerilascio a un pod dell'acceleratore, prima che il pod venga prerilasciato. La tolleranza al preempting dei percorsi è abilitata per impostazione predefinita in tutti i deployment cloud e i job dell'acceleratore Pathways sono in ascolto di queste notifiche.

    Quando arriva una notifica di preemption, Pathways determina innanzitutto se il workload attuale è ripristinabile, ovvero se può salvare e ripristinare il workload in modo trasparente. In questo caso, tenta di sospendere in modo trasparente il carico di lavoro ML scrivendo il suo stato attuale in uno spazio di archiviazione permanente come Cloud Storage prima che GKE espella i job dell'acceleratore. Quando GKE riprogramma i tuoi job in un secondo momento, Pathways riprende il tuo carico di lavoro ML leggendo il suo stato persistente.

    Se il workload non è ripristinabile, Pathways arresta il job dell'acceleratore e inoltra l'errore al tuo job se è configurato Elastic training. Se l'addestramento elastico non è configurato, GKE riavvia l'intero carico di lavoro in base al criterio di riavvio di JobSet.

    I tipici workload ML definiti utilizzando JAX si basano su componenti Pathways XLA senza stato che possono essere ripristinati utilizzando uno snapshot della memoria ad alta larghezza di banda (HBM). Alcuni carichi di lavoro ML, come quelli definiti utilizzando l'API Python collocate JAX, si basano su componenti stateful di Pathways, che non sono ripristinabili.

    Allenamento con elastici

    L'addestramento elastico consente al job di addestramento di continuare anche in caso di errori hardware. Ciò si ottiene grazie a una combinazione di funzionalità del sistema Pathways e logica di recupero del modello definita dall'utente:

    • Rilevamento di errori: quando si verifica un errore hardware (ad esempio, un worker TPU si arresta in modo anomalo), il sistema Pathways lo rileva e invia una notifica al job di addestramento dell'utente tramite un'eccezione al successivo accesso ai dati che si trovavano sull'hardware. Questa notifica non causa l'arresto anomalo del carico di lavoro, ma consente al codice di gestirla e riconfigurare le risorse per continuare l'elaborazione o uscire in modo controllato.
    • Gestore di elasticità definito dall'utente: il codice del modello dell'utente deve essere in grado di gestire questa eccezione. Per questo motivo si parla di "recupero specifico del modello".
      • Snapshot: l'approccio più comune è salvare periodicamente snapshot dello stato del modello. Quando si verifica un errore, puoi caricare lo snapshot più recente per riprendere l'addestramento.
      • Riconfigurazione: probabilmente dovrai riconfigurare il job di addestramento per adattarlo al numero di slice disponibili. Ad esempio, se una sezione smette di funzionare, puoi ridurre il numero di sezioni attive di una unità finché non è disponibile una sostituzione. Per saperne di più, consulta Elastic Handler.
      • Aggiornamenti del grafico di dati/calcoli: il codice deve gestire eventuali modifiche al numero di dispositivi disponibili per il calcolo ricreando il grafico di calcolo in base alle necessità. Ciò potrebbe comportare la ripartizione dei dati o la ricompilazione del modello.
    • Ruolo di Pathways nel recupero: Pathways fornisce le primitive per supportare la riconfigurazione definita dall'utente:
      • Sostituzione della sezione: se una sezione non riuscita viene sostituita, il client può essere informato una volta che la nuova sezione è disponibile. Il codice può quindi essere riconfigurato per utilizzare questa nuova sezione.
      • Recupero trasparente: Pathways gestisce i dettagli di livello inferiore del recupero, come il ristabilimento delle connessioni alle parti integre del cluster.
    • Utilità in pathwaysutils: un insieme di utilità di Pathways definite in pathways-utils.

    Implementa un gestore elastico

    La maggior parte del codice che dovrai scrivere si troverà in un gestore elastico definito dall'utente. Questo gestore reagisce agli eventi elastici (ad esempio, una sezione TPU non è più disponibile) ricreando la mesh e reinizializzando il ciclo di addestramento.

    Ogni carico di lavoro è unico. La complessità del gestore elastico può essere scalata in base alla complessità del workload. Gli input e gli output del gestore devono essere gli argomenti minimi e i valori restituiti necessari per reinizializzare il ciclo di addestramento.

    def elastic_handler(elastic_utils, *args, **kwargs):
      mesh = initialize_mesh(**kwargs["mesh_kwargs"])
      initial_state, initial_step, jitted_train_step, other_variables =
          initialize_training_loop(mesh, **kwargs["initialize_training_loop_kwargs"])
    
      step, snapshot = elastic_utils.get_next_snapshot()
      state = initial_state.replace(**snapshot)
    
      return state, step, mesh, jitted_train_step, other_variables
    

    Aggiornare il loop di addestramento

    Devi apportare le seguenti modifiche al ciclo di addestramento:

    1. Crea un gestore elastico
    2. Racchiudi il ciclo di addestramento all'interno di blocchi try-except che gestiscono jax.errors.JaxRuntimeError
    3. All'interno del gestore jax.errors.JaxRuntimeError, chiama maybe_reshard_down. Elastic Manager eseguirà il resharding verso il basso se l'errore è correlato a un evento elastico o lo riproporrà.
    4. Chiama maybe_snapshot e maybe_reshard_up alla fine del ciclo di addestramento
    import pathwaysutils
    from pathwaysutils.elastic import manager
    
    pathwaysutils.initialize()
    
    def initialize_mesh(**kwargs):
      ...
    
    
    def initialize_training_loop(**kwargs):
      ...
    
    
    def train_loop(
        final_step,
        elastic_manager,
        mesh_kwargs,
        initialize_training_loop_kwargs,
    ):
      mesh = initialize_mesh(**mesh_kwargs)
      initial_state, initial_step, jitted_train_step, other_variables =
          initialize_training_loop(mesh, **initialize_training_loop_kwargs)
    
      step = initial_step
      while step < final_step:
        try:
          state = jitted_train_step(state)
    
          elastic_manager.maybe_snapshot(step=step, snapshot=state)
          handler_returns = elastic_manager.maybe_reshard_up(
              step=step,
              snapshot=state,
              elastic_handler=elastic_handler,
              handler_args=(),
              handler_kwargs=dict(
                  mesh_kwargs=mesh_kwargs,
                  initialize_training_loop_kwargs=initialize_training_loop_kwargs,
              ),
          )
          if handler_returns:
            state, step, mesh, jitted_train_step, other_variables = handler_returns
          step += 1
        except jax.errors.JaxRuntimeError as error:
          handler_returns = elastic_manager.maybe_reshard_down(
              error=error,
              elastic_handler=elastic_handler,
              handler_args=(),
              handler_kwargs=dict(
                  mesh_kwargs=mesh_kwargs,
                  initialize_training_loop_kwargs=initialize_training_loop_kwargs,
              ),
          )
          if handler_returns:
            state, step, mesh, jitted_train_step, other_variables = handler_returns
    
      return state
    
    
    def main():
      elastic_manager = manager.Manager(
          devices=jax.devices(),
          snapshot_period=10,
          snapshot_buffer_size=1,
          reshard_check_period=5,
          max_elastic_down_event_count=10,
          max_reshard_retry_count=3,
      )
    
      train_loop(100, elastic_manager, {}, {})
    

    Configura Elastic Manager

    Il gestore elastico può essere configurato in vari modi. La frequenza di creazione degli snapshot è determinata dal periodo di snapshot. Il periodo dello snapshot influisce sul numero medio di passi persi a causa di un evento elastico. Il periodo di controllo del resharding determina la frequenza con cui il ciclo di addestramento esegue il polling per verificare la disponibilità delle sezioni. max_elastic_down_event_count consente di impostare il numero di eventi elastici dovuti alla perdita di slice supportati dal ciclo di addestramento. max_reshard_retry_count specifica il numero di tentativi di nuovo partizionamento che deve eseguire Elastic Manager. Il gestore è un oggetto singleton e deve essere creato una sola volta.

    Snapshot

    In base alla configurazione del gestore elastico, la funzione potrebbe creare uno snapshot dei dati nella memoria host che sarà disponibile per l'utilizzo da parte del gestore elastico durante un evento elastico.

    Ridurre lo sharding

    Dopo aver rilevato un jax.errors.JaxRuntimeError, Pathways controllerà se l'errore è dovuto a un evento elastico a causa di una sezione persa. In questo caso, chiamerà il gestore elastico in un ciclo fino al successo o al raggiungimento del numero massimo di tentativi. Se l'errore non è dovuto a un evento elastico, verrà generato di nuovo. I valori restituiti del gestore elastico vengono passati al chiamante.

    Aumentare lo sharding

    In base alla configurazione di Elastic Manager e se sono presenti slice non disponibili, Pathways verificherà se sono disponibili slice aggiuntivi. In questo caso, verrà immediatamente salvato uno snapshot (se non è già stato acquisito uno snapshot preesistente per il passaggio corrente) e verrà chiamato il gestore elastico in un ciclo fino al completamento dell'operazione o al raggiungimento del numero massimo di tentativi. Se si verifica il re-sharding, i valori restituiti del gestore elastico vengono passati al chiamante. In caso contrario, viene restituito None.

    Hot-swap

    Lo scambio a caldo si riferisce a una funzionalità dell'API GKE JobSet in cui un job con priorità più alta può assumere rapidamente il controllo delle risorse di un job con priorità inferiore, riducendo al minimo i tempi di inattività e garantendo un ripristino più rapido.

    Quando viene creato un JobSet, GKE pianifica il workload in più slice, come specificato nella configurazione del JobSet. Se si verifica un errore hardware su una o più sezioni, i pod interessati vengono contrassegnati come non riusciti. Quando riprogrammi questo JobSet, se hai scelto di mantenere una slice di riserva nel cluster GKE che potrebbe essere utilizzata per un job a priorità inferiore, il sistema JobSet rimappa il workload della slice non riuscita del job a priorità più alta sulla slice di riserva utilizzata dal job a priorità inferiore all'interno dello stesso cluster GKE. Questa rimappatura in genere richiede meno di un minuto.

    Al riavvio di JobSet, lo scambio a caldo può verificarsi nelle seguenti situazioni:

    1. Modalità predefinita: se sono disponibili sezioni TPU di riserva e inattive all'interno dello stesso cluster, lo scheduler Kubernetes darà la priorità alla pianificazione dei job riavviati su queste sezioni anziché attendere la riparazione delle sezioni non riuscite. Ciò consente un recupero più rapido.
    2. Workload eterogenei: nei cluster che eseguono più workload con una PriorityClass Kubernetes configurata, un JobSet riavviato può attivare uno scambio a caldo. Se l'affinità del job riavviato corrisponde alle risorse di un job con priorità inferiore, Kubernetes esegue la preemption del job con priorità inferiore, consentendo l'avvio immediato del job con priorità più alta. Ad esempio, puoi configurare i pod worker di Pathways con priorità diverse utilizzando PriorityClass.

    Per utilizzare le priorità nel cluster, definisci una classe di priorità, ad esempio:

    kind: PriorityClass
    metadata:
      name: high-prior-job
    value: 2000
    globalDefault: false
    description: "This priority class should be used for high priority job."
    

    Applica questo file YAML al tuo cluster GKE:

    kubectl apply -f high-prior-job.yaml
    

    Successivamente, collega la nuova classe di priorità al job del worker Pathways aggiungendo il seguente testo al podspec del tuo pod pathways-worker.

    priorityClassName: high-prior-job
    

    Passaggi successivi