Entrenamiento resiliente con Rutas

Las rutas proporcionan beneficios de resiliencia de las siguientes maneras:

  • Suspensión y reanudación: Tolerancia ante interrupciones planificadas, como avisos de preferencia, sin necesidad de que el usuario escriba código personalizado de control de preferencia.
  • Entrenamiento elástico: Tolerancia ante fallas de hardware no planificadas sin que se bloquee el cliente, pero requiere que los usuarios escriban código de recuperación específico del modelo.

    Antes de comenzar

    Asegúrate de tener lo siguiente:

    Suspender y reanudar

    Por lo general, GKE envía un aviso de interrupción a un pod de acelerador antes de que se interrumpa. La tolerancia a la interrupción de rutas está habilitada de forma predeterminada en todas las implementaciones en la nube, y los trabajos del acelerador de rutas escuchan estos avisos.

    Cuando llega un aviso de interrupción, Pathways primero determina si la carga de trabajo actual se puede restablecer, es decir, si Pathways puede guardar y restablecer la carga de trabajo de forma transparente. Si es así, intenta suspender de forma transparente tu carga de trabajo de AA escribiendo su estado actual en el almacenamiento persistente, como Cloud Storage, antes de que GKE expulse tus trabajos del acelerador. Cuando GKE reprograma tus trabajos más adelante, Pathways reanuda tu carga de trabajo de AA leyendo su estado persistente.

    Si la carga de trabajo no se puede restablecer, Pathways cierra el trabajo del acelerador y reenvía la falla a tu trabajo si se configuró el entrenamiento elástico. Si no se configura el entrenamiento elástico, GKE reinicia toda la carga de trabajo según la política de reinicio de JobSet.

    Las cargas de trabajo típicas de AA definidas con JAX se basan en componentes de Pathways XLA sin estado que se pueden restablecer con una instantánea de memoria de ancho de banda alto (HBM). Ciertas cargas de trabajo de AA, como las que se definen con la API de Python colocada de JAX, dependen de componentes de Pathways con estado, que no se pueden restablecer.

    Entrenamiento elástico

    El entrenamiento elástico permite que tu trabajo de entrenamiento continúe incluso cuando se producen fallas de hardware. Esto se logra a través de una combinación de capacidades del sistema de Pathways y lógica de recuperación del modelo definida por el usuario:

    • Detección de fallas: Cuando se produce una falla de hardware (por ejemplo, se bloquea un trabajador de TPU), el sistema de Pathways la detecta y notifica el trabajo de entrenamiento del usuario a través de una excepción la próxima vez que se acceda a los datos que se encontraban en ese hardware. Esta notificación no falla tu carga de trabajo, sino que permite que tu código la controle y reconfigure tus recursos para continuar con el procesamiento o salir correctamente.
    • Controlador de elasticidad definido por el usuario: El código del modelo del usuario debe poder controlar esta excepción. Esto es lo que la convierte en una "recuperación específica del modelo".
      • Creación de instantáneas: El enfoque más común es guardar periódicamente instantáneas del estado de tu modelo. Cuando se produce un error, puedes cargar la instantánea más reciente para reanudar el entrenamiento.
      • Reconfiguración: Es probable que debas reconfigurar tu trabajo de entrenamiento para ajustarlo a la cantidad de segmentos disponibles. Por ejemplo, si una porción deja de funcionar, puedes reducir la cantidad de porciones activas en una hasta que haya un reemplazo disponible. Para obtener más información, consulta Elastic Handler.
      • Actualizaciones del grafo de datos o de procesamiento: Tu código debe controlar cualquier cambio en la cantidad de dispositivos disponibles para tu procesamiento recreando el grafo de procesamiento según sea necesario. Esto podría implicar volver a particionar los datos o volver a compilar tu modelo.
    • El rol de Pathways en la recuperación: Pathways proporciona las primitivas para admitir la reconfiguración definida por el usuario:
      • Reemplazo de segmentos: Si se reemplaza un segmento con errores, se puede informar al cliente cuando esté disponible el segmento nuevo. Luego, tu código puede reconfigurarse para usar este nuevo segmento.
      • Recuperación transparente: Pathways controla los detalles de nivel inferior de la recuperación, como el restablecimiento de las conexiones a las partes en buen estado del clúster.
    • Utilidades en pathwaysutils: Es un conjunto de utilidades de Pathways definidas en pathways-utils.

    Implementa un controlador elástico

    La mayor parte del código que tendrás que escribir estará en un controlador elástico definido por el usuario. Este controlador reacciona a los eventos elásticos (como cuando una porción de TPU deja de estar disponible) recreando la malla y reinicializando el bucle de entrenamiento.

    Cada carga de trabajo es única. La complejidad del controlador elástico puede aumentar con la complejidad de la carga de trabajo. Las entradas y salidas del controlador deben ser los argumentos y valores de retorno mínimos necesarios para reinicializar el bucle de entrenamiento.

    def elastic_handler(elastic_utils, *args, **kwargs):
      mesh = initialize_mesh(**kwargs["mesh_kwargs"])
      initial_state, initial_step, jitted_train_step, other_variables =
          initialize_training_loop(mesh, **kwargs["initialize_training_loop_kwargs"])
    
      step, snapshot = elastic_utils.get_next_snapshot()
      state = initial_state.replace(**snapshot)
    
      return state, step, mesh, jitted_train_step, other_variables
    

    Actualiza tu bucle de entrenamiento

    Debes realizar los siguientes cambios en tu bucle de entrenamiento:

    1. Crea un administrador elástico
    2. Encapsula tu bucle de entrenamiento dentro de bloques try-except que controlen jax.errors.JaxRuntimeErrors.
    3. Dentro de tu controlador jax.errors.JaxRuntimeError, llama a maybe_reshard_down. El administrador elástico reducirá la cantidad de fragmentos si el error está relacionado con un evento elástico o, de lo contrario, lo volverá a generar.
    4. Llama a maybe_snapshot y maybe_reshard_up al final del bucle de entrenamiento
    import pathwaysutils
    from pathwaysutils.elastic import manager
    
    pathwaysutils.initialize()
    
    def initialize_mesh(**kwargs):
      ...
    
    
    def initialize_training_loop(**kwargs):
      ...
    
    
    def train_loop(
        final_step,
        elastic_manager,
        mesh_kwargs,
        initialize_training_loop_kwargs,
    ):
      mesh = initialize_mesh(**mesh_kwargs)
      initial_state, initial_step, jitted_train_step, other_variables =
          initialize_training_loop(mesh, **initialize_training_loop_kwargs)
    
      step = initial_step
      while step < final_step:
        try:
          state = jitted_train_step(state)
    
          elastic_manager.maybe_snapshot(step=step, snapshot=state)
          handler_returns = elastic_manager.maybe_reshard_up(
              step=step,
              snapshot=state,
              elastic_handler=elastic_handler,
              handler_args=(),
              handler_kwargs=dict(
                  mesh_kwargs=mesh_kwargs,
                  initialize_training_loop_kwargs=initialize_training_loop_kwargs,
              ),
          )
          if handler_returns:
            state, step, mesh, jitted_train_step, other_variables = handler_returns
          step += 1
        except jax.errors.JaxRuntimeError as error:
          handler_returns = elastic_manager.maybe_reshard_down(
              error=error,
              elastic_handler=elastic_handler,
              handler_args=(),
              handler_kwargs=dict(
                  mesh_kwargs=mesh_kwargs,
                  initialize_training_loop_kwargs=initialize_training_loop_kwargs,
              ),
          )
          if handler_returns:
            state, step, mesh, jitted_train_step, other_variables = handler_returns
    
      return state
    
    
    def main():
      elastic_manager = manager.Manager(
          devices=jax.devices(),
          snapshot_period=10,
          snapshot_buffer_size=1,
          reshard_check_period=5,
          max_elastic_down_event_count=10,
          max_reshard_retry_count=3,
      )
    
      train_loop(100, elastic_manager, {}, {})
    

    Configura el administrador elástico

    El administrador elástico se puede configurar de varias maneras. La frecuencia de las instantáneas se determina según el período de instantáneas. El período de instantánea afecta la cantidad promedio de pasos perdidos debido a un evento elástico. El período de verificación de nuevo fragmentado determina la frecuencia con la que tu bucle de entrenamiento sondeará la disponibilidad de segmentos. El parámetro max_elastic_down_event_count te permite establecer la cantidad de eventos elásticos que se deben a la pérdida de segmentación que admitirá tu bucle de entrenamiento. El parámetro max_reshard_retry_count especifica la cantidad de veces que el administrador elástico debe reintentar el cambio de fragmentación. El administrador es un objeto singleton y solo se debe crear una vez.

    Instantáneas

    Según la configuración del administrador elástico, la función puede tomar una instantánea de los datos en la memoria del host que estará disponible para que la use tu controlador elástico durante un evento elástico.

    Reduce el sharding

    Después de detectar un jax.errors.JaxRuntimeError, Pathways verificará si el error se debe a un evento elástico por una pérdida de segmento. Si es así, llamará al controlador elástico en un bucle hasta que se realice correctamente o hasta que se alcance la cantidad máxima de reintentos. Si el error no se debe a un evento elástico, se volverá a generar. Los valores de devolución del controlador elástico se pasan al llamador.

    Aumenta la cantidad de fragmentos

    Según la configuración del administrador elástico y si hay segmentos no disponibles, Rutas verificará si hay segmentos adicionales disponibles. Si es así, guardará de inmediato una instantánea (si aún no se había tomado una instantánea preexistente para el paso actual) y llamará al controlador elástico en un bucle hasta que se realice correctamente o se alcance la cantidad máxima de intentos de reintento. Si se vuelve a fragmentar, los valores de devolución del controlador elástico se pasan al llamador. De lo contrario, se muestra None.

    Intercambio en caliente

    El intercambio en caliente hace referencia a una función de la API de JobSet de GKE en la que un trabajo de mayor prioridad puede hacerse cargo rápidamente de los recursos de un trabajo de menor prioridad, lo que minimiza el tiempo de inactividad y garantiza una recuperación más rápida.

    Cuando se crea un JobSet, GKE programa la carga de trabajo en varias porciones, como se especifica en la configuración del JobSet. Si se produce una falla de hardware en una o más porciones, los Pods afectados se marcan como con errores. Cuando se reprograma este Jobset, si elegiste conservar una porción de repuesto en tu clúster de GKE que se podría utilizar para un trabajo de menor prioridad, el sistema de JobSet volverá a asignar la carga de trabajo de la porción con errores del trabajo de mayor prioridad a la porción de repuesto que utiliza el trabajo de menor prioridad dentro del mismo clúster de GKE. Por lo general, este proceso demora menos de un minuto.

    Tras el reinicio de JobSet, el intercambio en caliente puede ocurrir en las siguientes situaciones:

    1. Modo predeterminado: Si hay segmentos de TPU inactivos y de reserva disponibles en el mismo clúster, el programador de Kubernetes priorizará la programación de los trabajos reiniciados en estos segmentos en lugar de esperar a que se reparen los segmentos con errores. Esto proporciona una recuperación más rápida.
    2. Cargas de trabajo heterogéneas: En los clústeres que ejecutan varias cargas de trabajo con una PriorityClass de Kubernetes configurada, un JobSet reiniciado puede activar un intercambio en caliente. Si la afinidad del trabajo reiniciado coincide con los recursos de un trabajo de menor prioridad, Kubernetes anula el trabajo de menor prioridad, lo que permite que el trabajo de mayor prioridad comience de inmediato. Por ejemplo, puedes configurar tus Pods de trabajadores de Pathways con diferentes prioridades usando PriorityClass.

    Para usar prioridades en tu clúster, define una clase de prioridad, por ejemplo:

    kind: PriorityClass
    metadata:
      name: high-prior-job
    value: 2000
    globalDefault: false
    description: "This priority class should be used for high priority job."
    

    Aplica este archivo YAML a tu clúster de GKE:

    kubectl apply -f high-prior-job.yaml
    

    A continuación, agrega la nueva clase de prioridad al trabajo de tu trabajador de Pathways agregando el siguiente texto al podspec de tu Pod pathways-worker.

    priorityClassName: high-prior-job
    

    ¿Qué sigue?