Manter o progresso do treinamento usando o Autocheckpoint

Historicamente, quando uma VM de TPU exige manutenção, o procedimento é iniciado imediatamente, sem que haja tempo para que os usuários realizem ações de preservação de progresso, como salvar um checkpoint. Isso é mostrado na Figura 1(a).

Diagrama mostrando o impacto da manutenção de host com e sem o Autocheckpoint

Fig. 1. Ilustração do recurso Autocheckpoint: (a) Sem o Autocheckpoint, o progresso do treinamento desde o último checkpoint é perdido quando há um evento de manutenção programada. (b) Com o Autocheckpoint, o progresso do treinamento desde o último checkpoint pode ser preservado quando há um evento de manutenção programada.

Use o Autocheckpoint, de acordo com a Figura 1(b), para preservar o progresso do treinamento configurando o código para salvar um checkpoint não programado em caso de um evento de manutenção. Quando um evento de manutenção ocorre, o progresso desde o último checkpoint é salvo automaticamente. O recurso funciona com frações únicas e com várias frações.

O recurso Autocheckpoint funciona com frameworks que podem capturar sinais SIGTERM e salvar um checkpoint. Os frameworks aceitos incluem:

Como usar o Autocheckpoint

O Autocheckpoint fica desativado por padrão. Ao criar uma TPU ou solicitar um recurso em fila, é possível ativar o Autocheckpoint adicionando a flag --autocheckpoint-enabled ao provisionar a TPU. Com o recurso ativado, o Cloud TPU realiza as seguintes etapas quando recebe a notificação de um evento de manutenção:

  1. Capturar o sinal SIGTERM enviado ao processo usando o dispositivo de TPU.
  2. Aguardar até que o processo seja encerrado ou até que cinco minutos tenham se passado, o que ocorrer primeiro.
  3. Realizar manutenção nas frações afetadas.

A infraestrutura usada pelo Autocheckpoint é independente do framework de ML. Qualquer framework de ML poderá trabalhar com o Autocheckpoint se conseguir capturar o sinal SIGTERM e iniciar um processo de checkpoint.

No código do aplicativo, é necessário ativar os recursos do Autocheckpoint fornecidos pelo framework de ML. No Pax, por exemplo, isso significa ativar flags de linha de comando ao iniciar o treinamento. Para mais informações, consulte o guia de início rápido do Autocheckpoint com Pax. Os frameworks salvam um checkpoint não programado quando um sinal SIGTERM é recebido, e a VM de TPU afetada passa por manutenção quando não está mais em uso.

Guia de início rápido: Autocheckpoint com MaxText

O MaxText é um LLM de alto desempenho, escalonável de maneira arbitrária, de código aberto e bem testado, escrito em Python/JAX puro para Cloud TPUs. Ele contém toda a configuração necessária para usar o recurso Autocheckpoint.

O arquivo README do MaxText descreve duas maneiras de executar o MaxText em grande escala:

Ao usar multihost_runner.py, ative o Autocheckpoint definindo a flag autocheckpoint-enabled ao provisionar o recurso em fila.

Ao usar multihost_job.py, ative o Autocheckpoint especificando a flag de linha de comando ENABLE_AUTOCHECKPOINT=true ao iniciar o job.

Guia de início rápido: Autocheckpoint com Pax em uma única fração

Nesta seção, mostramos um exemplo de como configurar e usar o Autocheckpoint com o Pax em uma única fração. Com a configuração adequada:

  • Um checkpoint será salvo quando um evento de manutenção ocorrer.
  • O Cloud TPU vai realizar a manutenção nas VMs de TPU afetadas depois que o checkpoint for salvo.
  • Quando a manutenção do Cloud TPU for concluída, você poderá usar a VM de TPU normalmente.
  1. Use a flag autocheckpoint-enabled ao criar a VM de TPU ou solicitar um recurso em fila.

    Por exemplo:

    1. Defina as variáveis de ambiente:

      export PROJECT_ID=your-project-id
      export TPU_NAME=your-tpu-name
      export ZONE=zone-you-want-to-use
      export ACCELERATOR_TYPE=your-accelerator-type
      export RUNTIME_VERSION=tpu-ubuntu2204-base

      Descrições de variáveis de ambiente

      Variável Descrição
      PROJECT_ID O ID do projeto do Google Cloud . Use um projeto atual ou crie um novo.
      TPU_NAME O nome da TPU.
      ZONE A zona em que a VM de TPU será criada. Para mais informações sobre as zonas disponíveis, consulte Zonas e regiões de TPU.
      ACCELERATOR_TYPE O tipo de acelerador especifica a versão e o tamanho do Cloud TPU que você quer criar. Para mais informações sobre os tipos de aceleradores aceitos por cada versão de TPU, consulte Versões de TPU.
      RUNTIME_VERSION A versão do software do Cloud TPU.

    2. Defina o ID do projeto e a zona na configuração ativa:

      gcloud config set project $PROJECT_ID
      gcloud config set compute/zone $ZONE
    3. Crie uma TPU:

      gcloud alpha compute tpus tpu-vm create $TPU_NAME \
          --accelerator-type $ACCELERATOR_TYPE \
          --version $RUNTIME_VERSION \
          --autocheckpoint-enabled
  2. Conecte-se à TPU usando SSH:

    gcloud compute tpus tpu-vm ssh $TPU_NAME
    
  3. Instale o Pax em uma única fração.

    O recurso Autocheckpoint funciona nas versões 1.1.0 e mais recentes do Pax. Na VM de TPU, instale o jax[tpu] e o paxml mais recente:

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  4. Configure o modelo LmCloudSpmd2B. Antes de executar o script de treinamento, mude ICI_MESH_SHAPE para [1, 8, 1]:

    @experiment_registry.register
    class LmCloudSpmd2B(LmCloudSpmd):
    
        """SPMD model with 2B params.
    
        Global batch size = 2 * 2 * 1 * 32 = 128
        """
        PERCORE_BATCH_SIZE = 8
    
        NUM_LAYERS = 18
        MODEL_DIMS = 3072
        HIDDEN_DIMS = MODEL_DIMS * 4
    
        CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING
        ICI_MESH_SHAPE = [1, 8, 1]
  5. Inicie o treinamento com a configuração adequada.

    O exemplo a seguir mostra como configurar o modelo LmCloudSpmd2B para salvar checkpoints acionados pelo Autocheckpoint em um bucket do Cloud Storage. Substitua your-storage-bucket pelo nome de um bucket atual ou crie um novo bucket.

    export JOB_LOG_DIR=gs://your-storage-bucket
    
    { python3 .local/lib/python3.10/site-packages/paxml/main.py \
        --jax_fully_async_checkpoint=1 \
        --exit_after_ondemand_checkpoint=1 \
        --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \
        --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt

    Observe as duas flags transmitidas ao comando:

    • jax_fully_async_checkpoint: com essa flag ativada, orbax.checkpoint.AsyncCheckpointer é usado. A classe AsyncCheckpointer salva automaticamente um checkpoint quando o script de treinamento recebe um sinal SIGTERM.
    • exit_after_ondemand_checkpoint: com essa flag ativada, o processo da TPU é encerrado depois que o Autocheckpoint é salvo, o que aciona imediatamente a realização da manutenção. Se você não usar essa flag, o treinamento vai continuar depois que o checkpoint for salvo e o Cloud TPU vai aguardar um tempo limite (cinco minutos) antes de realizar a manutenção necessária.

Autocheckpoint com Orbax

O recurso Autocheckpoint não está limitado ao MaxText ou ao Pax. Qualquer framework que possa capturar o sinal SIGTERM e iniciar um processo de criação de checkpoint funciona com a infraestrutura fornecida pelo Autocheckpoint. O Orbax, um namespace que oferece bibliotecas de utilitários comuns para usuários do JAX, oferece essas funcionalidades.

Conforme explicado na documentação do Orbax, esses recursos são ativados por padrão para os usuários do orbax.checkpoint.CheckpointManager. O método save, que é chamado após cada etapa, verifica automaticamente se um evento de manutenção está prestes a acontecer e, em caso afirmativo, salva um checkpoint mesmo que o número da etapa não seja um múltiplo de save_interval_steps. A documentação do GitHub também ilustra como fazer com que o treinamento seja encerrado após salvar um Autocheckpoint, com uma modificação no código do usuário.