Pathways offre des avantages en termes de résilience de plusieurs façons :
- Suspendre/Reprendre : tolérance face aux interruptions planifiées, comme les notifications de préemption, sans que l'utilisateur ait besoin d'écrire de code personnalisé de gestion de la préemption.
- Entraînement élastique : tolérance face aux défaillances matérielles non planifiées sans provoquer le plantage du client, mais nécessitant que les utilisateurs écrivent un code de récupération spécifique au modèle.
Avant de commencer
Vérifiez que vous disposez bien des éléments suivants :
Suspendre/Reprendre
En règle générale, GKE envoie une notification de préemption à un pod d'accélérateur avant que le pod ne soit préempté. La tolérance de préemption de Pathways est activée par défaut sur tous les déploiements cloud, et les tâches d'accélérateur Pathways écoutent ces notifications.
Lorsqu'une notification de préemption arrive, Pathways détermine d'abord si la charge de travail actuelle est restaurable, c'est-à-dire si Pathways peut enregistrer et restaurer la charge de travail de manière transparente. Si c'est le cas, il tente de suspendre de manière transparente votre charge de travail de ML en écrivant son état actuel dans un stockage persistant tel que Cloud Storage avant que GKE n'évince vos tâches d'accélérateur. Lorsque GKE reprogramme vos tâches ultérieurement, Pathways reprend votre charge de travail de ML en relisant son état persistant.
Si la charge de travail n'est pas restaurable, Pathways arrête la tâche d'accélérateur et transmet l'échec à votre tâche si l'entraînement élastique est configuré. Si l'entraînement élastique n'est pas configuré, GKE redémarre l'intégralité de la charge de travail en fonction de la règle de redémarrage JobSet.
Les charges de travail de ML typiques définies à l'aide de JAX s'appuient sur des composants Pathways XLA sans état qui peuvent être restaurés à l'aide d'un instantané de mémoire à bande passante élevée (HBM). Certaines charges de travail de ML , telles que celles définies à l'aide de l'API Python colocalisée JAX , s'appuient sur des composants Pathways avec état. Elles ne sont pas restaurables.
Entraînement élastique
L'entraînement élastique permet à votre tâche d'entraînement de se poursuivre même en cas de défaillance matérielle. Pour ce faire, il combine les capacités du système Pathways et la logique de récupération du modèle définie par l'utilisateur :
- Détection des défaillances : en cas de défaillance matérielle (par exemple, en cas de plantage d'un nœud de calcul TPU ), le système Pathways la détecte et en informe la tâche d'entraînement de l'utilisateur via une exception lors du prochain accès aux données situées sur ce matériel. Cette notification ne plante pas votre charge de travail. Elle permet à votre code de gérer la notification et de reconfigurer vos ressources pour poursuivre le traitement ou quitter correctement.
- Gestionnaire d'élasticité défini par l'utilisateur : le code du modèle de l'utilisateur doit être en mesure de
gérer cette exception. C'est ce qui en fait une "récupération spécifique au modèle".
- Instantanés : l'approche la plus courante consiste à enregistrer régulièrement des instantanés de l'état de votre modèle. En cas de défaillance, vous pouvez charger l'instantané le plus récent pour reprendre l'entraînement.
- Reconfiguration : vous devrez probablement reconfigurer votre tâche d'entraînement pour l'adapter au nombre de tranches disponibles. Par exemple, si une tranche cesse de fonctionner, vous pouvez réduire le nombre de tranches actives d'une unité jusqu'à ce qu'un remplacement soit disponible. Pour en savoir plus, consultez la section Gestionnaire élastique.
- Mises à jour du graphique de données/de calcul : votre code doit gérer toutes les modifications du nombre d’appareils disponibles pour votre calcul en recréant le graphique de calcul si nécessaire. Cela peut impliquer de repartitionner les données ou de recompiler votre modèle.
- Rôle de Pathways dans la récupération : Pathways fournit les primitives pour prendre en charge
la reconfiguration définie par l'utilisateur :
- Remplacement de tranche : si une tranche défaillante est remplacée, le client peut être informé une fois la nouvelle tranche disponible. Votre code peut ensuite être reconfiguré pour utiliser cette nouvelle tranche.
- Récupération transparente : Pathways gère les détails de récupération de niveau inférieur, comme le rétablissement des connexions aux parties saines du cluster.
- Utilitaires dans pathwaysutils : ensemble d'utilitaires Pathways définis dans pathways-utils.
Implémenter un gestionnaire élastique
La majeure partie du code que vous devrez écrire se trouvera dans un gestionnaire élastique défini par l'utilisateur. Ce gestionnaire réagit aux événements élastiques (par exemple, lorsqu'une tranche TPU devient indisponible) en recréant le maillage et en réinitialisant la boucle d'entraînement.
Chaque charge de travail est unique. La complexité du gestionnaire élastique peut évoluer en fonction de la complexité de la charge de travail. Les entrées et les sorties du gestionnaire doivent être les arguments et les valeurs de retour minimales nécessaires pour réinitialiser la boucle d'entraînement.
def elastic_handler(elastic_utils, *args, **kwargs):
mesh = initialize_mesh(**kwargs["mesh_kwargs"])
initial_state, initial_step, jitted_train_step, other_variables =
initialize_training_loop(mesh, **kwargs["initialize_training_loop_kwargs"])
step, snapshot = elastic_utils.get_next_snapshot()
state = initial_state.replace(**snapshot)
return state, step, mesh, jitted_train_step, other_variables
Mettre à jour votre boucle d'entraînement
Vous devez apporter les modifications suivantes à votre boucle d'entraînement :
- Créer un gestionnaire élastique
- Encapsuler votre boucle d'entraînement dans des blocs try-except qui gèrent les
jax.errors.JaxRuntimeError - Dans votre gestionnaire
jax.errors.JaxRuntimeError, appelezmaybe_reshard_down. Le gestionnaire élastique effectue un resharding si l'erreur est liée à un événement élastique ou la relance. - Appelez
maybe_snapshotetmaybe_reshard_upà la fin de la boucle d'entraînement.
import pathwaysutils
from pathwaysutils.elastic import manager
pathwaysutils.initialize()
def initialize_mesh(**kwargs):
...
def initialize_training_loop(**kwargs):
...
def train_loop(
final_step,
elastic_manager,
mesh_kwargs,
initialize_training_loop_kwargs,
):
mesh = initialize_mesh(**mesh_kwargs)
initial_state, initial_step, jitted_train_step, other_variables =
initialize_training_loop(mesh, **initialize_training_loop_kwargs)
step = initial_step
while step < final_step:
try:
state = jitted_train_step(state)
elastic_manager.maybe_snapshot(step=step, snapshot=state)
handler_returns = elastic_manager.maybe_reshard_up(
step=step,
snapshot=state,
elastic_handler=elastic_handler,
handler_args=(),
handler_kwargs=dict(
mesh_kwargs=mesh_kwargs,
initialize_training_loop_kwargs=initialize_training_loop_kwargs,
),
)
if handler_returns:
state, step, mesh, jitted_train_step, other_variables = handler_returns
step += 1
except jax.errors.JaxRuntimeError as error:
handler_returns = elastic_manager.maybe_reshard_down(
error=error,
elastic_handler=elastic_handler,
handler_args=(),
handler_kwargs=dict(
mesh_kwargs=mesh_kwargs,
initialize_training_loop_kwargs=initialize_training_loop_kwargs,
),
)
if handler_returns:
state, step, mesh, jitted_train_step, other_variables = handler_returns
return state
def main():
elastic_manager = manager.Manager(
devices=jax.devices(),
snapshot_period=10,
snapshot_buffer_size=1,
reshard_check_period=5,
max_elastic_down_event_count=10,
max_reshard_retry_count=3,
)
train_loop(100, elastic_manager, {}, {})
Configurer le gestionnaire élastique
Le gestionnaire élastique peut être configuré de plusieurs façons. La fréquence des instantanés est déterminée par la période d'instantané. La période d'instantané affecte le nombre moyen d'étapes perdues en raison d'un événement élastique. La période de vérification du resharding détermine la fréquence à laquelle votre boucle d'entraînement interroge la disponibilité des tranches.
Le paramètre max_elastic_down_event_count vous permet de définir le nombre d'événements élastiques dus à la perte de tranches que votre boucle d'entraînement prendra en charge. Le paramètre max_reshard_retry_count spécifie le nombre de fois où le gestionnaire élastique doit réessayer le resharding. Le gestionnaire est un objet singleton et ne doit être créé qu'une seule fois.
Instantanés
En fonction de la configuration du gestionnaire élastique, la fonction peut créer un instantané des données dans la mémoire de l'hôte, qui sera disponible pour votre gestionnaire élastique lors d'un événement élastique.
Réduire le sharding
Après avoir intercepté une jax.errors.JaxRuntimeError, Pathways vérifie si l'erreur est due à un événement élastique dû à une tranche perdue. Si c'est le cas, il appelle le gestionnaire élastique dans une boucle jusqu'à ce qu'il réussisse ou qu'il atteigne le nombre maximal de tentatives. Si l'erreur n'est pas due à un événement élastique, elle est relancée. Les valeurs renvoyées par le gestionnaire élastique sont transmises à l'appelant.
Augmenter le sharding
En fonction de la configuration du gestionnaire élastique et s'il existe des tranches indisponibles, Pathways vérifie si des tranches supplémentaires sont devenues disponibles. Si c'est le cas, il enregistre immédiatement un instantané (si un instantané préexistant pour l'étape actuelle n'a pas déjà été pris) et appelle le gestionnaire élastique dans une boucle jusqu'à ce qu'il réussisse ou qu'il atteigne le nombre maximal de tentatives. Si un resharding se produit, les valeurs renvoyées par le gestionnaire élastique sont transmises à l'appelant. Sinon, None est renvoyé.
Échange à chaud
L'échange à chaud fait référence à une fonctionnalité de l'API GKE JobSet dans laquelle une tâche de priorité plus élevée peut rapidement prendre le contrôle des ressources d'une tâche de priorité inférieure, ce qui réduit les temps d'arrêt et assure une récupération plus rapide.
Lorsqu'un JobSet est créé, GKE planifie la charge de travail sur plusieurs tranches, comme spécifié dans la configuration JobSet. En cas de défaillance matérielle sur une ou plusieurs tranches, les pods concernés sont marqués comme ayant échoué. Lors de la reprogrammation de ce JobSet, si vous avez choisi de conserver une tranche de rechange dans votre cluster GKE qui pourrait être utilisée pour une tâche de priorité inférieure, le système JobSet remappera la charge de travail de la tranche défaillante de la tâche de priorité supérieure sur la tranche de rechange utilisée par la tâche de priorité inférieure dans le même cluster GKE. Ce remapping prend généralement moins d'une minute.
Lors du redémarrage de JobSet, l'échange à chaud peut se produire dans les situations suivantes :
- Mode par défaut : si des tranches TPU de rechange et inactives sont disponibles dans le même cluster, le planificateur Kubernetes planifie en priorité les tâches redémarrées sur ces tranches plutôt que d'attendre la réparation des tranches défaillantes. Cela permet une récupération plus rapide.
- Charges de travail hétérogènes : dans les clusters exécutant plusieurs charges de travail avec
une PriorityClass Kubernetes configurée, un JobSet redémarré peut déclencher un échange à chaud. Si l'affinité de la tâche redémarrée correspond aux ressources d'une tâche de priorité inférieure, Kubernetes préempte la tâche de priorité inférieure, ce qui permet à la tâche de priorité supérieure de démarrer immédiatement. Par exemple, vous pouvez configurer vos pods de nœuds de calcul Pathways avec différentes priorités à l'aide de
PriorityClass.
Pour utiliser des priorités dans votre cluster, définissez une classe de priorité, par exemple :
kind: PriorityClass
metadata:
name: high-prior-job
value: 2000
globalDefault: false
description: "This priority class should be used for high priority job."
Appliquez ce code YAML à votre cluster GKE :
kubectl apply -f high-prior-job.yaml
Ensuite, associez la nouvelle classe de priorité à votre tâche de nœud de calcul Pathways en ajoutant le texte suivant au podspec de votre pod pathways-worker.
priorityClassName: high-prior-job
Étape suivante
- Créer un cluster GKE avec Pathways
- Exécuter une charge de travail par lot avec Pathways
- Exécuter une charge de travail interactive avec Pathways
- Effectuer une inférence multihôte à l'aide de Pathways
- Transférer des charges de travail JAX vers Pathways
- Résoudre les problèmes liés à Pathways dans le cloud