Treinamento resiliente com programas de aprendizado

O Pathways oferece benefícios de resiliência das seguintes maneiras:

  • Suspender/retomar: tolerância em caso de interrupções planejadas, como avisos de preempção, sem que o usuário precise escrever um código de tratamento de preempção personalizado.
  • Treinamento elástico: tolerância em caso de falhas de hardware não planejadas sem causar falhas no cliente, mas exigindo que os usuários escrevam um código de recuperação específico do modelo.

Antes de começar

Você precisa ter:

Suspender/retomar

Normalmente, o GKE envia um aviso de preempção para um pod de acelerador antes que ele seja preemptivo. A tolerância de preempção do Pathways é ativada por padrão em todas as implantações na nuvem, e os jobs do acelerador do Pathways ficam atentos a esses avisos.

Quando um aviso de preempção chega, o Pathways primeiro determina se a carga de trabalho atual pode ser restaurada, ou seja, se o Pathways pode salvar e restaurar a carga de trabalho de forma transparente. Se for o caso, ele tenta suspender de forma transparente a carga de trabalho de ML gravando o estado atual no armazenamento permanente, como o Cloud Storage, antes que o GKE remova os jobs do acelerador. Quando o GKE reprograma seus jobs mais tarde, o Pathways retoma a carga de trabalho de ML lendo o estado persistido.

Se a carga de trabalho não puder ser restaurada, o Pathways vai encerrar o job do acelerador e encaminhar a falha para o job se o treinamento elástico estiver configurado. Se o treinamento elástico não estiver configurado, o GKE vai reiniciar toda a carga de trabalho com base na política de reinicialização do JobSet.

As cargas de trabalho de ML típicas definidas usando o JAX dependem de componentes XLA do Pathways sem estado que podem ser restaurados usando um snapshot de memória de alta largura de banda (HBM, na sigla em inglês). Algumas cargas de trabalho de ML , como as definidas usando a API Python colocada do JAX , dependem de componentes do Pathways com estado. Elas não podem ser restauradas.

Treinamento elástico

O treinamento elástico permite que o job de treinamento continue mesmo quando ocorrem falhas de hardware. Isso é feito por uma combinação de recursos do sistema Pathways e lógica de recuperação de modelo definida pelo usuário:

  • Detecção de falha: quando ocorre uma falha de hardware (por exemplo, um worker de TPU falha), o sistema Pathways detecta isso e notifica o job de treinamento do usuário por meio de uma exceção na próxima vez que os dados localizados nesse hardware forem acessados. Essa notificação não causa falhas na carga de trabalho. Ela permite que o código processe a notificação e reconfigure os recursos para continuar o processamento ou sair normalmente.
  • Gerenciador de elasticidade definido pelo usuário: o código do modelo do usuário precisa ser capaz de processar essa exceção. É isso que torna a recuperação específica do modelo.
    • Snapshot: a abordagem mais comum é salvar snapshots do estado do modelo periodicamente. Quando ocorre uma falha, é possível carregar o snapshot mais recente para retomar o treinamento.
    • Reconfiguração: provavelmente será necessário reconfigurar o job de treinamento para ajustar o número de fatias disponíveis. Por exemplo, se uma fatia parar de funcionar, você poderá reduzir o número de fatias ativas em uma até que uma substituição esteja disponível. Para mais informações, consulte Gerenciador elástico.
    • Atualizações do gráfico de dados/computação: o código precisa processar todas as mudanças no número de dispositivos disponíveis para a computação, recriando o gráfico de computação conforme necessário. Isso pode envolver a repartição de dados ou a recompilação do modelo.
  • Papel do Pathways na recuperação: o Pathways fornece os primitivos para oferecer suporte à reconfiguração definida pelo usuário:
    • Substituição de fatias: se uma fatia com falha for substituída, o cliente poderá ser informado quando a nova fatia estiver disponível. Em seguida, o código poderá ser reconfigurado para usar essa nova fatia.
    • Recuperação transparente: o Pathways processa os detalhes de nível inferior da recuperação, como o restabelecimento de conexões com as partes íntegras do cluster.
  • Utilitários em pathwaysutils: um conjunto de utilitários do Pathways definidos em pathways-utils.

Implementar um gerenciador elástico

A maior parte do código que você terá que escrever estará em um gerenciador elástico definido pelo usuário. Esse gerenciador reage a eventos elásticos (como uma fatia de TPU que fica indisponível) recriando a malha e reinicializando o loop de treinamento.

Cada carga de trabalho é única. A complexidade do gerenciador elástico pode ser dimensionada com a complexidade da carga de trabalho. As entradas e saídas do gerenciador precisam ser os argumentos mínimos e retornar os valores necessários para reinicializar o loop de treinamento.

def elastic_handler(elastic_utils, *args, **kwargs):
  mesh = initialize_mesh(**kwargs["mesh_kwargs"])
  initial_state, initial_step, jitted_train_step, other_variables =
      initialize_training_loop(mesh, **kwargs["initialize_training_loop_kwargs"])

  step, snapshot = elastic_utils.get_next_snapshot()
  state = initial_state.replace(**snapshot)

  return state, step, mesh, jitted_train_step, other_variables

Atualizar o loop de treinamento

É necessário fazer as seguintes mudanças no loop de treinamento:

  1. Criar um gerenciador elástico
  2. Encapsular o loop de treinamento em blocos try-except que processam jax.errors.JaxRuntimeErrors
  3. No gerenciador jax.errors.JaxRuntimeError, chame maybe_reshard_down. O gerenciador elástico vai reduzir a fragmentação se o erro estiver relacionado a um evento elástico ou, caso contrário, vai gerar novamente.
  4. Chame maybe_snapshot e maybe_reshard_up no final do loop de treinamento
import pathwaysutils
from pathwaysutils.elastic import manager

pathwaysutils.initialize()

def initialize_mesh(**kwargs):
  ...


def initialize_training_loop(**kwargs):
  ...


def train_loop(
    final_step,
    elastic_manager,
    mesh_kwargs,
    initialize_training_loop_kwargs,
):
  mesh = initialize_mesh(**mesh_kwargs)
  initial_state, initial_step, jitted_train_step, other_variables =
      initialize_training_loop(mesh, **initialize_training_loop_kwargs)

  step = initial_step
  while step < final_step:
    try:
      state = jitted_train_step(state)

      elastic_manager.maybe_snapshot(step=step, snapshot=state)
      handler_returns = elastic_manager.maybe_reshard_up(
          step=step,
          snapshot=state,
          elastic_handler=elastic_handler,
          handler_args=(),
          handler_kwargs=dict(
              mesh_kwargs=mesh_kwargs,
              initialize_training_loop_kwargs=initialize_training_loop_kwargs,
          ),
      )
      if handler_returns:
        state, step, mesh, jitted_train_step, other_variables = handler_returns
      step += 1
    except jax.errors.JaxRuntimeError as error:
      handler_returns = elastic_manager.maybe_reshard_down(
          error=error,
          elastic_handler=elastic_handler,
          handler_args=(),
          handler_kwargs=dict(
              mesh_kwargs=mesh_kwargs,
              initialize_training_loop_kwargs=initialize_training_loop_kwargs,
          ),
      )
      if handler_returns:
        state, step, mesh, jitted_train_step, other_variables = handler_returns

  return state


def main():
  elastic_manager = manager.Manager(
      devices=jax.devices(),
      snapshot_period=10,
      snapshot_buffer_size=1,
      reshard_check_period=5,
      max_elastic_down_event_count=10,
      max_reshard_retry_count=3,
  )

  train_loop(100, elastic_manager, {}, {})

Configurar o gerenciador elástico

O gerenciador elástico pode ser configurado de algumas maneiras diferentes. A frequência de snapshots é determinada pelo período de snapshot. O período de snapshot afeta o número médio de etapas perdidas devido a um evento elástico. O período de verificação de fragmentação determina com que frequência o loop de treinamento vai pesquisar a disponibilidade de fatias. O max_elastic_down_event_count permite definir o número de eventos elásticos devido à perda de fatias que o loop de treinamento vai oferecer suporte. O max_reshard_retry_count especifica o número de vezes que o gerenciador elástico precisa tentar a fragmentação novamente. O gerenciador é um objeto singleton e precisa ser criado apenas uma vez.

Snapshots

Com base na configuração do gerenciador elástico, a função pode fazer um snapshot de dados na memória do host que estará disponível para uso pelo gerenciador elástico durante um evento elástico.

Reduzir a fragmentação

Depois de detectar um jax.errors.JaxRuntimeError, o Pathways vai verificar se o erro é devido a um evento elástico causado por uma fatia perdida. Se for o caso, ele vai chamar o gerenciador elástico em um loop até que a operação seja bem-sucedida ou o número máximo de tentativas seja atingido. Se o erro não for devido a um evento elástico, ele será gerado novamente. Os valores de retorno do gerenciador elástico são transmitidos ao autor da chamada.

Aumentar a fragmentação

Com base na configuração do gerenciador elástico e se houver fatias indisponíveis, o Pathways vai verificar se mais fatias ficaram disponíveis. Se for o caso, ele vai salvar imediatamente um snapshot (se um snapshot pré-existente para a etapa atual ainda não tiver sido feito) e chamar o gerenciador elástico em um loop até que a operação seja bem-sucedida ou o número máximo de tentativas seja atingido. Se a fragmentação ocorrer novamente, os valores de retorno do gerenciador elástico serão transmitidos ao autor da chamada. Caso contrário, None será retornado.

Hot-swap

O hot-swap se refere a um recurso da API JobSet do GKE em que um job de maior prioridade pode assumir rapidamente os recursos de um job de menor prioridade, minimizando o tempo de inatividade e garantindo uma recuperação mais rápida.

Quando um JobSet é criado, o GKE programa a carga de trabalho em várias fatias, conforme especificado na configuração do JobSet. Se ocorrer uma falha de hardware em uma ou mais fatias, os pods afetados serão marcados como com falha. Ao reprogramar esse Jobset, se você tiver optado por manter uma fatia extra no cluster do GKE que possa ser usada para um job de menor prioridade, o sistema JobSet vai remapear a carga de trabalho da fatia com falha do job de maior prioridade para a fatia extra que está sendo usada pelo job de menor prioridade no mesmo cluster do GKE. Esse remapeamento geralmente leva menos de um minuto.

Após a reinicialização do JobSet, o hot-swap pode ocorrer nas seguintes situações:

  1. Modo padrão: se houver fatias de TPU extras e inativas disponíveis no mesmo cluster, o programador do Kubernetes vai priorizar a programação dos jobs reiniciados nessas fatias em vez de esperar que as fatias com falha sejam reparadas. Isso proporciona uma recuperação mais rápida.
  2. Cargas de trabalho heterogêneas: em clusters que executam várias cargas de trabalho com uma PriorityClass do Kubernetes configurada, um JobSet reiniciado pode acionar um hot swap. Se a afinidade do job reiniciado corresponder aos recursos de um job de menor prioridade, o Kubernetes vai antecipar o job de menor prioridade, permitindo que o job de maior prioridade seja iniciado imediatamente. Por exemplo, é possível configurar os pods de worker do Pathways com prioridades diferentes usando PriorityClass.

Para usar prioridades no cluster, defina uma classe de prioridade, por exemplo:

kind: PriorityClass
metadata:
  name: high-prior-job
value: 2000
globalDefault: false
description: "This priority class should be used for high priority job."

Aplique esse YAML ao cluster do GKE:

kubectl apply -f high-prior-job.yaml

Em seguida, anexe a nova classe de prioridade ao job de worker do Pathways adicionando o texto a seguir ao podspec do pod pathways-worker.

priorityClassName: high-prior-job

A seguir