Pathways offre vantaggi di resilienza nei seguenti modi:
- Sospensione-Ripresa: tolleranza di fronte a interruzioni pianificate come avvisi di preemptive senza che l'utente debba scrivere codice personalizzato per la gestione del preemptive.
- Elastic Training: tolleranza di fronte a guasti hardware non pianificati
senza causare l'arresto anomalo del client, ma richiedendo agli utenti di scrivere codice di ripristino
specifico per il modello.
Prima di iniziare
Assicurati di avere:
- XPK installato
- Strumenti Kubernetes installati
- Installazione di gcloud CLI
- Abilitato l'API TPU
- Abilitato l'API GKE
Sospendi-riprendi
In genere, GKE invia una notifica di prerilascio a un pod dell'acceleratore, prima che il pod venga prerilasciato. La tolleranza al preempting dei percorsi è abilitata per impostazione predefinita in tutti i deployment cloud e i job dell'acceleratore Pathways sono in ascolto di queste notifiche.
Quando arriva una notifica di preemption, Pathways determina innanzitutto se il workload attuale è ripristinabile, ovvero se può salvare e ripristinare il workload in modo trasparente. In questo caso, tenta di sospendere in modo trasparente il carico di lavoro ML scrivendo il suo stato attuale in uno spazio di archiviazione permanente come Cloud Storage prima che GKE espella i job dell'acceleratore. Quando GKE riprogramma i tuoi job in un secondo momento, Pathways riprende il tuo carico di lavoro ML leggendo il suo stato persistente.
Se il workload non è ripristinabile, Pathways arresta il job dell'acceleratore e inoltra l'errore al tuo job se è configurato Elastic training. Se l'addestramento elastico non è configurato, GKE riavvia l'intero carico di lavoro in base al criterio di riavvio di JobSet.
I tipici workload ML definiti utilizzando JAX si basano su componenti Pathways XLA senza stato che possono essere ripristinati utilizzando uno snapshot della memoria ad alta larghezza di banda (HBM). Alcuni carichi di lavoro ML, come quelli definiti utilizzando l'API Python collocate JAX, si basano su componenti stateful di Pathways, che non sono ripristinabili.
Allenamento con elastici
L'addestramento elastico consente al job di addestramento di continuare anche in caso di errori hardware. Ciò si ottiene grazie a una combinazione di funzionalità del sistema Pathways e logica di recupero del modello definita dall'utente:
- Rilevamento di errori: quando si verifica un errore hardware (ad esempio, un worker TPU si arresta in modo anomalo), il sistema Pathways lo rileva e invia una notifica al job di addestramento dell'utente tramite un'eccezione al successivo accesso ai dati che si trovavano sull'hardware. Questa notifica non causa l'arresto anomalo del carico di lavoro, ma consente al codice di gestirla e riconfigurare le risorse per continuare l'elaborazione o uscire in modo controllato.
- Gestore di elasticità definito dall'utente: il codice del modello dell'utente deve essere in grado di
gestire questa eccezione. Per questo motivo si parla di "recupero specifico del modello".
- Snapshot: l'approccio più comune è salvare periodicamente snapshot dello stato del modello. Quando si verifica un errore, puoi caricare lo snapshot più recente per riprendere l'addestramento.
- Riconfigurazione: probabilmente dovrai riconfigurare il job di addestramento per adattarlo al numero di slice disponibili. Ad esempio, se una sezione smette di funzionare, puoi ridurre il numero di sezioni attive di una unità finché non è disponibile una sostituzione. Per saperne di più, consulta Elastic Handler.
- Aggiornamenti del grafico di dati/calcoli: il codice deve gestire eventuali modifiche al numero di dispositivi disponibili per il calcolo ricreando il grafico di calcolo in base alle necessità. Ciò potrebbe comportare la ripartizione dei dati o la ricompilazione del modello.
- Ruolo di Pathways nel recupero: Pathways fornisce le primitive per supportare
la riconfigurazione definita dall'utente:
- Sostituzione della sezione: se una sezione non riuscita viene sostituita, il client può essere informato una volta che la nuova sezione è disponibile. Il codice può quindi essere riconfigurato per utilizzare questa nuova sezione.
- Recupero trasparente: Pathways gestisce i dettagli di livello inferiore del recupero, come il ristabilimento delle connessioni alle parti integre del cluster.
- Utilità in pathwaysutils: un insieme di utilità di Pathways definite in pathways-utils.
Implementa un gestore elastico
La maggior parte del codice che dovrai scrivere si troverà in un gestore elastico definito dall'utente. Questo gestore reagisce agli eventi elastici (ad esempio, una sezione TPU non è più disponibile) ricreando la mesh e reinizializzando il ciclo di addestramento.
Ogni carico di lavoro è unico. La complessità del gestore elastico può essere scalata in base alla complessità del workload. Gli input e gli output del gestore devono essere gli argomenti minimi e i valori restituiti necessari per reinizializzare il ciclo di addestramento.
def elastic_handler(elastic_utils, *args, **kwargs): mesh = initialize_mesh(**kwargs["mesh_kwargs"]) initial_state, initial_step, jitted_train_step, other_variables = initialize_training_loop(mesh, **kwargs["initialize_training_loop_kwargs"]) step, snapshot = elastic_utils.get_next_snapshot() state = initial_state.replace(**snapshot) return state, step, mesh, jitted_train_step, other_variablesAggiornare il loop di addestramento
Devi apportare le seguenti modifiche al ciclo di addestramento:
- Crea un gestore elastico
- Racchiudi il ciclo di addestramento all'interno di blocchi try-except che gestiscono
jax.errors.JaxRuntimeError - All'interno del gestore
jax.errors.JaxRuntimeError, chiamamaybe_reshard_down. Elastic Manager eseguirà il resharding verso il basso se l'errore è correlato a un evento elastico o lo riproporrà. - Chiama
maybe_snapshotemaybe_reshard_upalla fine del ciclo di addestramento
import pathwaysutils from pathwaysutils.elastic import manager pathwaysutils.initialize() def initialize_mesh(**kwargs): ... def initialize_training_loop(**kwargs): ... def train_loop( final_step, elastic_manager, mesh_kwargs, initialize_training_loop_kwargs, ): mesh = initialize_mesh(**mesh_kwargs) initial_state, initial_step, jitted_train_step, other_variables = initialize_training_loop(mesh, **initialize_training_loop_kwargs) step = initial_step while step < final_step: try: state = jitted_train_step(state) elastic_manager.maybe_snapshot(step=step, snapshot=state) handler_returns = elastic_manager.maybe_reshard_up( step=step, snapshot=state, elastic_handler=elastic_handler, handler_args=(), handler_kwargs=dict( mesh_kwargs=mesh_kwargs, initialize_training_loop_kwargs=initialize_training_loop_kwargs, ), ) if handler_returns: state, step, mesh, jitted_train_step, other_variables = handler_returns step += 1 except jax.errors.JaxRuntimeError as error: handler_returns = elastic_manager.maybe_reshard_down( error=error, elastic_handler=elastic_handler, handler_args=(), handler_kwargs=dict( mesh_kwargs=mesh_kwargs, initialize_training_loop_kwargs=initialize_training_loop_kwargs, ), ) if handler_returns: state, step, mesh, jitted_train_step, other_variables = handler_returns return state def main(): elastic_manager = manager.Manager( devices=jax.devices(), snapshot_period=10, snapshot_buffer_size=1, reshard_check_period=5, max_elastic_down_event_count=10, max_reshard_retry_count=3, ) train_loop(100, elastic_manager, {}, {})Configura Elastic Manager
Il gestore elastico può essere configurato in vari modi. La frequenza di creazione degli snapshot è determinata dal periodo di snapshot. Il periodo dello snapshot influisce sul numero medio di passi persi a causa di un evento elastico. Il periodo di controllo del resharding determina la frequenza con cui il ciclo di addestramento esegue il polling per verificare la disponibilità delle sezioni.
max_elastic_down_event_countconsente di impostare il numero di eventi elastici dovuti alla perdita di slice supportati dal ciclo di addestramento.max_reshard_retry_countspecifica il numero di tentativi di nuovo partizionamento che deve eseguire Elastic Manager. Il gestore è un oggetto singleton e deve essere creato una sola volta.Snapshot
In base alla configurazione del gestore elastico, la funzione potrebbe creare uno snapshot dei dati nella memoria host che sarà disponibile per l'utilizzo da parte del gestore elastico durante un evento elastico.
Ridurre lo sharding
Dopo aver rilevato un
jax.errors.JaxRuntimeError, Pathways controllerà se l'errore è dovuto a un evento elastico a causa di una sezione persa. In questo caso, chiamerà il gestore elastico in un ciclo fino al successo o al raggiungimento del numero massimo di tentativi. Se l'errore non è dovuto a un evento elastico, verrà generato di nuovo. I valori restituiti del gestore elastico vengono passati al chiamante.Aumentare lo sharding
In base alla configurazione di Elastic Manager e se sono presenti slice non disponibili, Pathways verificherà se sono disponibili slice aggiuntivi. In questo caso, verrà immediatamente salvato uno snapshot (se non è già stato acquisito uno snapshot preesistente per il passaggio corrente) e verrà chiamato il gestore elastico in un ciclo fino al completamento dell'operazione o al raggiungimento del numero massimo di tentativi. Se si verifica il re-sharding, i valori restituiti del gestore elastico vengono passati al chiamante. In caso contrario, viene restituito
None.Hot-swap
Lo scambio a caldo si riferisce a una funzionalità dell'API GKE JobSet in cui un job con priorità più alta può assumere rapidamente il controllo delle risorse di un job con priorità inferiore, riducendo al minimo i tempi di inattività e garantendo un ripristino più rapido.
Quando viene creato un JobSet, GKE pianifica il workload in più slice, come specificato nella configurazione del JobSet. Se si verifica un errore hardware su una o più sezioni, i pod interessati vengono contrassegnati come non riusciti. Quando riprogrammi questo JobSet, se hai scelto di mantenere una slice di riserva nel cluster GKE che potrebbe essere utilizzata per un job a priorità inferiore, il sistema JobSet rimappa il workload della slice non riuscita del job a priorità più alta sulla slice di riserva utilizzata dal job a priorità inferiore all'interno dello stesso cluster GKE. Questa rimappatura in genere richiede meno di un minuto.
Al riavvio di JobSet, lo scambio a caldo può verificarsi nelle seguenti situazioni:
- Modalità predefinita: se sono disponibili sezioni TPU di riserva e inattive all'interno dello stesso cluster, lo scheduler Kubernetes darà la priorità alla pianificazione dei job riavviati su queste sezioni anziché attendere la riparazione delle sezioni non riuscite. Ciò consente un recupero più rapido.
- Workload eterogenei: nei cluster che eseguono più workload con una PriorityClass Kubernetes configurata, un JobSet riavviato può attivare uno scambio a caldo. Se l'affinità del job riavviato corrisponde alle risorse di un job con priorità inferiore, Kubernetes esegue la preemption del job con priorità inferiore, consentendo l'avvio immediato del job con priorità più alta. Ad esempio, puoi configurare i pod worker di
Pathways con priorità diverse utilizzando
PriorityClass.
Per utilizzare le priorità nel cluster, definisci una classe di priorità, ad esempio:
kind: PriorityClass metadata: name: high-prior-job value: 2000 globalDefault: false description: "This priority class should be used for high priority job."Applica questo file YAML al tuo cluster GKE:
kubectl apply -f high-prior-job.yamlSuccessivamente, collega la nuova classe di priorità al job del worker Pathways aggiungendo il seguente testo al podspec del tuo pod
pathways-worker.priorityClassName: high-prior-jobPassaggi successivi