자동 체크포인트를 사용하여 학습 진행 상태 보존

지금까지 TPU VM에 유지보수가 필요할 때는 사용자가 체크포인트 저장과 같은 진행 보존 작업을 수행할 시간 없이 절차가 즉시 시작되었습니다. 그림 1(a)에서는 이를 보여줍니다.

자동 체크포인트를 사용하는 경우와 사용하지 않는 경우의 호스트 유지보수 영향을 보여주는 다이어그램

그림 1. 자동 체크포인트 기능 그림: (a) 자동 체크포인트를 사용하지 않으면 예정된 유지보수 이벤트가 있을 때 마지막 체크포인트의 학습 진행 상황이 손실됩니다. (b) 자동 체크포인트를 사용하면 예정된 유지보수 이벤트가 있을 때 마지막 체크포인트 이후의 학습 진행 상황을 보존할 수 있습니다.

자동 체크포인트(그림 1(b))를 사용하면 유지보수 이벤트가 발생할 때 예약되지 않은 체크포인트를 저장하도록 코드를 구성하여 학습 진행 상황을 보존할 수 있습니다. 유지보수 이벤트가 발생하면 마지막 체크포인트 이후의 진행 상황이 자동으로 저장됩니다. 이 기능은 단일 슬라이스와 멀티슬라이스 모두에서 작동합니다.

자동 체크포인트 기능은 SIGTERM 신호를 캡처하고 이후에 체크포인트를 저장하는 프레임워크에서 작동합니다. 지원되는 프레임워크는 다음과 같습니다.

자동 체크포인트 사용

자동 체크포인트 기능은 기본적으로 사용 중지되어 있습니다. TPU를 만들거나 또는 큐에 추가된 리소스를 요청한 경우 TPU를 프로비저닝할 때 --autocheckpoint-enabled 플래그를 추가하여 자동 체크포인트를 사용 설정할 수 있습니다. 이 기능을 사용 설정하면 Cloud TPU에서 유지보수 이벤트 알림을 수신하면 다음 단계를 수행합니다.

  1. TPU 기기를 사용하여 프로세스에 전송된 SIGTERM 신호를 캡처합니다.
  2. 프로세스가 종료되거나 5분이 경과될 때까지 기다립니다.
  3. 영향을 받는 리소스에서 유지보수를 수행합니다.

자동 체크포인트에 사용하는 인프라는 ML 프레임워크에 독립적입니다. SIGTERM 신호를 캡처하고 체크포인트 지정 프로세스를 시작할 수 있으면 모든 ML 프레임워크에서 자동 체크포인트를 지원할 수 있습니다.

애플리케이션 코드에서 ML 프레임워크에서 제공한 자동 체크포인트 기능을 사용 설정해야 합니다. 예를 들어 Pax에서는 학습을 시작할 때 명령줄 플래그를 사용 설정해야 합니다. 자세한 내용은 Pax를 사용한 자동 체크포인트 빠른 시작을 참조하세요. 이 과정 중에 프레임워크는 SIGTERM 신호가 수신될 때 예약되지 않은 체크포인트를 저장하고 TPU가 더 이상 사용되지 않으면 영향을 받는 TPU VM이 유지보수됩니다.

빠른 시작: MaxText를 사용한 자동 체크포인트

MaxText는 Cloud TPU를 타겟팅하는 순수 Python/JAX로 작성되어 임의로 확장 가능하고 테스트를 거친 고성능 오픈소스 LLM입니다. MaxText에는 자동 체크포인트 기능을 사용하는 데 필요한 모든 설정이 포함되어 있습니다.

MaxText README 파일에서는 규모에 맞게 MaxText를 실행할 수 있는 두 가지 방법을 설명합니다.

multihost_runner.py를 사용할 경우 큐에 추가된 리소스를 프로비저닝할 때 autocheckpoint-enabled 플래그를 설정하여 자동 체크포인트를 사용 설정합니다.

multihost_job.py를 사용할 경우에는 작업을 실행할 때 ENABLE_AUTOCHECKPOINT=true 명령줄 플래그를 지정하여 자동 체크포인트를 사용 설정합니다.

빠른 시작: 단일 슬라이스에서 Pax로 자동 체크포인트

이 섹션에서는 단일 슬라이스에서 Pax와 함께 자동 체크포인트를 설정하고 사용하는 방법의 예시를 보여줍니다. 다음과 같이 되도록 적절하게 설정합니다.

  • 유지보수 이벤트가 발생하면 체크포인트가 저장됩니다.
  • 체크포인트가 저장된 후 영향을 받는 TPU VM에서 Cloud TPU가 유지보수를 수행합니다.
  • Cloud TPU에서 유지보수를 완료하면 일반적으로 TPU VM을 사용할 수 있습니다.
  1. TPU VM을 만들거나 큐에 추가된 리소스를 요청할 때 autocheckpoint-enabled 플래그를 사용합니다.

    예를 들면 다음과 같습니다.

    1. 환경 변수를 설정합니다.

      export PROJECT_ID=your-project-id
      export TPU_NAME=your-tpu-name
      export ZONE=zone-you-want-to-use
      export ACCELERATOR_TYPE=your-accelerator-type
      export RUNTIME_VERSION=tpu-ubuntu2204-base

      환경 변수 설명

      변수 설명
      PROJECT_ID Google Cloud 프로젝트 ID입니다. 기존 프로젝트를 사용하거나 새 프로젝트를 만듭니다.
      TPU_NAME TPU 이름입니다.
      ZONE TPU VM을 만들 영역입니다. 지원되는 영역에 대한 자세한 내용은 TPU 리전 및 영역을 참조하세요.
      ACCELERATOR_TYPE 액셀러레이터 유형은 만들려는 Cloud TPU의 버전과 크기를 지정합니다. 각 TPU 버전에서 지원되는 액셀러레이터 유형에 대한 자세한 내용은 TPU 버전을 참조하세요.
      RUNTIME_VERSION Cloud TPU 소프트웨어 버전입니다.

    2. 활성 구성에서 프로젝트 ID와 영역을 설정합니다.

      gcloud config set project $PROJECT_ID
      gcloud config set compute/zone $ZONE
    3. TPU를 만듭니다.

      gcloud alpha compute tpus tpu-vm create $TPU_NAME \
          --accelerator-type $ACCELERATOR_TYPE \
          --version $RUNTIME_VERSION \
          --autocheckpoint-enabled
  2. SSH를 사용하여 TPU에 연결합니다.

    gcloud compute tpus tpu-vm ssh $TPU_NAME
    
  3. 단일 슬라이스에 Pax 설치

    자동 체크포인트 기능은 Pax 버전 1.1.0 이상에서 작동합니다. TPU VM에서 jax[tpu] 및 최신 paxml을 설치합니다.

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  4. LmCloudSpmd2B 모델을 구성합니다. 학습 스크립트를 실행하기 전에 ICI_MESH_SHAPE[1, 8, 1]로 변경합니다.

    @experiment_registry.register
    class LmCloudSpmd2B(LmCloudSpmd):
    
        """SPMD model with 2B params.
    
        Global batch size = 2 * 2 * 1 * 32 = 128
        """
        PERCORE_BATCH_SIZE = 8
    
        NUM_LAYERS = 18
        MODEL_DIMS = 3072
        HIDDEN_DIMS = MODEL_DIMS * 4
    
        CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING
        ICI_MESH_SHAPE = [1, 8, 1]
  5. 적절한 구성으로 학습을 시작합니다.

    다음 예시에서는 LmCloudSpmd2B 모델에서 자동 체크포인트로 트리거된 체크포인트를 Cloud Storage 버킷에 저장하도록 구성하는 방법을 보여줍니다. your-storage-bucket을 기존 버킷 이름으로 바꾸거나 새 버킷을 만듭니다.

    export JOB_LOG_DIR=gs://your-storage-bucket
    
    { python3 .local/lib/python3.10/site-packages/paxml/main.py \
        --jax_fully_async_checkpoint=1 \
        --exit_after_ondemand_checkpoint=1 \
        --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \
        --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt

    다음 플래그 2개가 명령어에 전달됩니다.

    • jax_fully_async_checkpoint: 이 플래그를 설정하면 orbax.checkpoint.AsyncCheckpointer가 사용됩니다. AsyncCheckpointer 클래스는 학습 스크립트에 SIGTERM 신호가 수신될 때 자동으로 체크포인트를 저장합니다.
    • exit_after_ondemand_checkpoint: 이 플래그를 설정하면 자동 체크포인트가 성공적으로 저장된 후에 TPU 프로세스가 종료되고 유지보수가 즉시 수행됩니다. 이 플래그를 사용하지 않으면 체크포인트가 저장된 후에도 학습이 계속되고 Cloud TPU가 제한 시간(5분)이 발생할 때까지 기다린 후 필요한 유지보수를 수행합니다.

Orbax를 사용한 자동 체크포인트

자동 체크포인트 기능은 MaxText 또는 Pax로 제한되지 않습니다. SIGTERM 신호를 캡처하고 체크포인트 프로세스를 시작할 수 있는 모든 프레임워크는 자동 체크포인트에서 제공한 인프라에서 작동합니다. JAX 사용자를 위한 일반적인 유틸리티 라이브러리를 제공하는 네임스페이스인 Orbax에서도 이러한 기능을 제공합니다.

Orbax 문서의 설명대로 이러한 기능은 기본적으로 orbax.checkpoint.CheckpointManager 사용자에게 사용 설정되어 있습니다. 모든 단계에서 곧 예정된 유지보수 이벤트가 있는지 여부를 자동으로 확인한 후에 호출되는 save 메서드는 단계 번호가 save_interval_steps 배수가 아니더라도 체크포인트를 저장합니다. 또한 GitHub 문서에서는 사용자 코드 수정과 함께 자동 체크포인트를 저장한 후 학습을 종료하는 방법을 설명합니다.