Treinamento resiliente com programas de aprendizado

Os programas oferecem benefícios de resiliência das seguintes maneiras:

  • Suspender/retomar: tolerância em caso de interrupções planejadas, como avisos de remoção sem que o usuário precise escrever um código personalizado de processamento de remoção.
  • Treinamento elástico: tolerância a falhas de hardware não planejadas sem causar falhas no cliente, mas exigindo que os usuários escrevam um código de recuperação específico do modelo.

    Antes de começar

    Você precisa ter:

    Suspender/retomar

    Normalmente, o GKE envia um aviso de preempção para um pod de acelerador antes que ele seja removido. A tolerância à substituição de caminhos é ativada por padrão em todas as implantações na nuvem, e os jobs do acelerador de caminhos ficam atentos a esses avisos.

    Quando um aviso de remoção chega, o Pathways primeiro determina se a carga de trabalho atual pode ser restaurada, ou seja, se o Pathways pode salvar e restaurar a carga de trabalho de maneira transparente. Se sim, ele tentará suspender de maneira transparente sua carga de trabalho de ML gravando o estado atual dela em um armazenamento permanente, como o Cloud Storage, antes que o GKE remova seus jobs de acelerador. Quando o GKE reagenda seus jobs mais tarde, o Pathways retoma sua carga de trabalho de ML lendo o estado persistente.

    Se a carga de trabalho não puder ser restaurada, o programa vai encerrar o job do acelerador e encaminhar a falha para seu job se o treinamento elástico estiver configurado. Se o treinamento elástico não estiver configurado, o GKE reiniciará toda a carga de trabalho com base na política de reinicialização do JobSet.

    As cargas de trabalho típicas de ML definidas usando o JAX dependem de componentes XLA do Pathways sem estado que podem ser restaurados usando um snapshot de memória de alta largura de banda (HBM). Algumas cargas de trabalho de ML, como as definidas usando a API Python de colocação do JAX, dependem de componentes com estado do Pathways, que não podem ser restaurados.

    Treinamento elástico

    O treinamento elástico permite que o job de treinamento continue mesmo quando ocorrem falhas de hardware. Isso é feito com uma combinação de recursos do sistema de programas e lógica de recuperação de modelo definida pelo usuário:

    • Detecção de falha: quando ocorre uma falha de hardware (por exemplo, um worker de TPU falha), o sistema do Pathways detecta isso e notifica o job de treinamento do usuário com uma exceção na próxima vez que os dados localizados nesse hardware forem acessados. Essa notificação não falha na sua carga de trabalho. Ela permite que seu código processe a notificação e reconfigure os recursos para continuar o processamento ou sair normalmente.
    • Processador de elasticidade definido pelo usuário: o código do modelo do usuário precisa ser capaz de processar essa exceção. É isso que torna a recuperação específica do modelo.
      • Snapshots: a abordagem mais comum é salvar periodicamente snapshots do estado do modelo. Quando ocorre uma falha, é possível carregar o snapshot mais recente para retomar o treinamento.
      • Reconfiguração: provavelmente será necessário reconfigurar o job de treinamento para ajustar o número de intervalos disponíveis. Por exemplo, se uma fração parar de funcionar, você poderá reduzir o número de frações ativas em uma até que uma substituição esteja disponível. Para mais informações, consulte Elastic Handler.
      • Atualizações do gráfico de dados/computação: seu código precisa processar mudanças no número de dispositivos disponíveis para a computação recriando o gráfico de computação conforme necessário. Isso pode envolver a repartição de dados ou a recompilação do modelo.
    • Função dos programas na recuperação: os programas oferecem as primitivas para oferecer suporte à reconfiguração definida pelo usuário:
      • Substituição de fração: se uma fração com falha for substituída, o cliente poderá ser informado quando a nova fração estiver disponível. Em seguida, seu código pode ser reconfigurado para usar essa nova fração.
      • Recuperação transparente: os Pathways processam os detalhes de nível mais baixo da recuperação, como o restabelecimento de conexões com as partes íntegras do cluster.
    • Utilitários em pathwaysutils: um conjunto de utilitários de programas de aprendizado definidos em pathways-utils.

    Implementar um manipulador elástico

    A maior parte do código que você vai escrever estará em um manipulador elástico definido pelo usuário. Esse manipulador reage a eventos elásticos (como uma fração de TPU ficar indisponível) recriando a malha e reinicializando o loop de treinamento.

    Cada carga de trabalho é única. A complexidade do manipulador elástico pode ser dimensionada com a complexidade da carga de trabalho. As entradas e saídas do manipulador precisam ser os argumentos e valores de retorno mínimos necessários para reinicializar o loop de treinamento.

    def elastic_handler(elastic_utils, *args, **kwargs):
      mesh = initialize_mesh(**kwargs["mesh_kwargs"])
      initial_state, initial_step, jitted_train_step, other_variables =
          initialize_training_loop(mesh, **kwargs["initialize_training_loop_kwargs"])
    
      step, snapshot = elastic_utils.get_next_snapshot()
      state = initial_state.replace(**snapshot)
    
      return state, step, mesh, jitted_train_step, other_variables
    

    Atualizar o loop de treinamento

    Faça as seguintes mudanças no loop de treinamento:

    1. Criar um gerenciador elástico
    2. Encapsule o loop de treinamento em blocos try-except que processam jax.errors.JaxRuntimeErrors.
    3. No manipulador jax.errors.JaxRuntimeError, chame maybe_reshard_down. O gerenciador elástico vai reduzir o refragmento se o erro estiver relacionado a um evento elástico ou vai gerar novamente.
    4. Chame maybe_snapshot e maybe_reshard_up no final do loop de treinamento
    import pathwaysutils
    from pathwaysutils.elastic import manager
    
    pathwaysutils.initialize()
    
    def initialize_mesh(**kwargs):
      ...
    
    
    def initialize_training_loop(**kwargs):
      ...
    
    
    def train_loop(
        final_step,
        elastic_manager,
        mesh_kwargs,
        initialize_training_loop_kwargs,
    ):
      mesh = initialize_mesh(**mesh_kwargs)
      initial_state, initial_step, jitted_train_step, other_variables =
          initialize_training_loop(mesh, **initialize_training_loop_kwargs)
    
      step = initial_step
      while step < final_step:
        try:
          state = jitted_train_step(state)
    
          elastic_manager.maybe_snapshot(step=step, snapshot=state)
          handler_returns = elastic_manager.maybe_reshard_up(
              step=step,
              snapshot=state,
              elastic_handler=elastic_handler,
              handler_args=(),
              handler_kwargs=dict(
                  mesh_kwargs=mesh_kwargs,
                  initialize_training_loop_kwargs=initialize_training_loop_kwargs,
              ),
          )
          if handler_returns:
            state, step, mesh, jitted_train_step, other_variables = handler_returns
          step += 1
        except jax.errors.JaxRuntimeError as error:
          handler_returns = elastic_manager.maybe_reshard_down(
              error=error,
              elastic_handler=elastic_handler,
              handler_args=(),
              handler_kwargs=dict(
                  mesh_kwargs=mesh_kwargs,
                  initialize_training_loop_kwargs=initialize_training_loop_kwargs,
              ),
          )
          if handler_returns:
            state, step, mesh, jitted_train_step, other_variables = handler_returns
    
      return state
    
    
    def main():
      elastic_manager = manager.Manager(
          devices=jax.devices(),
          snapshot_period=10,
          snapshot_buffer_size=1,
          reshard_check_period=5,
          max_elastic_down_event_count=10,
          max_reshard_retry_count=3,
      )
    
      train_loop(100, elastic_manager, {}, {})
    

    Configurar o gerenciador elástico

    O gerenciador elástico pode ser configurado de algumas maneiras diferentes. A frequência de snapshots é determinada pelo período de snapshot. O período do snapshot afeta o número médio de etapas perdidas devido a um evento elástico. O período de verificação de redistribuição determina a frequência com que o loop de treinamento vai sondar a disponibilidade de intervalos. O max_elastic_down_event_count permite definir o número de eventos elásticos devido à perda de fatia que seu loop de treinamento vai oferecer suporte. O max_reshard_retry_count especifica o número de vezes que o gerenciador elástico deve tentar fazer o refragmentação. O gerenciador é um objeto singleton e só precisa ser criado uma vez.

    Snapshots

    Com base na configuração do gerenciador elástico, a função pode criar um snapshot dos dados na memória do host, que estará disponível para uso pelo manipulador elástico durante um evento elástico.

    Reduzir o sharding

    Depois de capturar um jax.errors.JaxRuntimeError, o Pathways vai verificar se o erro é devido a um evento elástico por causa de uma fração perdida. Se for o caso, ele vai chamar o manipulador elástico em um loop até que a operação seja concluída ou o número máximo de tentativas seja atingido. Se o erro não for devido a um evento elástico, ele será gerado novamente. Os valores de retorno do manipulador elástico são transmitidos para o autor da chamada.

    Aumentar o sharding

    Com base na configuração do gerenciador elástico e se houver intervalos indisponíveis, os programas vão verificar se outros intervalos ficaram disponíveis. Se for, ele vai salvar imediatamente um snapshot (se um snapshot pré-existente para a etapa atual ainda não tiver sido feito) e chamar o manipulador elástico em um loop até que a operação seja concluída ou o número máximo de tentativas seja atingido. Se ocorrer um refragmentação, os valores de retorno do manipulador elástico serão transmitidos para o autor da chamada. Caso contrário, None será retornado.

    Hot-swap

    A troca a quente se refere a um recurso da API GKE JobSet em que um job de maior prioridade pode assumir rapidamente os recursos de um job de menor prioridade, minimizando o tempo de inatividade e garantindo uma recuperação mais rápida.

    Quando um JobSet é criado, o GKE programa a carga de trabalho em várias frações, conforme especificado na configuração do JobSet. Se ocorrer uma falha de hardware em uma ou mais frações, os pods afetados serão marcados como falha. Ao reagendar esse Jobset, se você tiver optado por manter uma fração sobressalente no cluster do GKE que poderia ser usada para um job de prioridade mais baixa, o sistema JobSet vai remapear a carga de trabalho da fração com falha do job de prioridade mais alta para a fração sobressalente usada pelo job de prioridade mais baixa no mesmo cluster do GKE. Esse remapeamento geralmente leva menos de um minuto.

    Após a reinicialização do JobSet, a troca a quente pode ocorrer nas seguintes situações:

    1. Modo padrão: se houver frações de TPU ociosas e sobressalentes disponíveis no mesmo cluster, o programador do Kubernetes vai priorizar a programação dos jobs reiniciados nessas frações em vez de esperar que as frações com falha sejam corrigidas. Isso permite uma recuperação mais rápida.
    2. Cargas de trabalho heterogêneas: em clusters que executam várias cargas de trabalho com uma PriorityClass do Kubernetes configurada, um JobSet reiniciado pode acionar uma troca a quente. Se a afinidade do job reiniciado corresponder aos recursos de um job de prioridade mais baixa, o Kubernetes vai antecipar o job de prioridade mais baixa, permitindo que o job de prioridade mais alta seja iniciado imediatamente. Por exemplo, é possível configurar os pods de worker do programa de aprendizado com diferentes prioridades usando PriorityClass.

    Para usar prioridades no cluster, defina uma classe de prioridade, por exemplo:

    kind: PriorityClass
    metadata:
      name: high-prior-job
    value: 2000
    globalDefault: false
    description: "This priority class should be used for high priority job."
    

    Aplique este YAML ao cluster do GKE:

    kubectl apply -f high-prior-job.yaml
    

    Em seguida, anexe a nova classe de prioridade ao trabalho do worker do Pathways adicionando o texto a seguir ao podspec do pod pathways-worker.

    priorityClassName: high-prior-job
    

    A seguir