학습 과정을 통한 탄력적인 학습

경로는 다음과 같은 방식으로 복원력 이점을 제공합니다.

  • 일시중지-재개: 사용자가 맞춤 선점 처리 코드를 작성하지 않아도 되는 선점 알림과 같은 계획된 중단에 대한 허용치입니다.
  • 탄력적 학습: 클라이언트가 비정상 종료되지 않으면서 계획되지 않은 하드웨어 오류에 대한 허용 오차를 제공하지만 사용자가 모델별 복구 코드를 작성해야 합니다.

    시작하기 전에

    다음 사항이 필요합니다.

    정지-재개

    일반적으로 GKE는 포드가 선점되기 전에 액셀러레이터 포드에 선점 알림을 전송합니다. 경로 선점 허용 범위는 모든 클라우드 배포에서 기본적으로 사용 설정되며 경로 가속기 작업은 이러한 알림을 수신합니다.

    선점 알림이 도착하면 Pathways는 먼저 현재 워크로드를 복원할 수 있는지, 즉 Pathways가 워크로드를 투명하게 저장하고 복원할 수 있는지 확인합니다. 이 경우 GKE가 액셀러레이터 작업을 제거하기 전에 현재 상태를 Cloud Storage와 같은 영구 스토리지에 써서 ML 워크로드를 투명하게 일시 중지하려고 시도합니다. 나중에 GKE에서 작업을 다시 예약하면 Pathways는 지속된 상태를 다시 읽어 ML 워크로드를 재개합니다.

    워크로드를 복원할 수 없는 경우 탄력적 학습이 구성되어 있으면 Pathways에서 액셀러레이터 작업을 종료하고 실패를 작업에 전달합니다. 탄력적 학습이 구성되지 않은 경우 GKE는 JobSet 다시 시작 정책에 따라 전체 워크로드를 다시 시작합니다.

    JAX를 사용하여 정의된 일반적인 ML 워크로드는 고대역폭 메모리 (HBM) 스냅샷을 사용하여 복원할 수 있는 상태 비저장 Pathways XLA 구성요소에 의존합니다. JAX 공동 배치 Python API를 사용하여 정의된 것과 같은 특정 ML 워크로드는 상태 저장 Pathways 구성요소를 사용하며, 이러한 구성요소는 복원할 수 없습니다.

    탄력적 학습

    탄력적 학습을 사용하면 하드웨어 장애가 발생해도 학습 작업을 계속할 수 있습니다. 이는 경로 시스템 기능과 사용자 정의 모델 복구 로직의 조합을 통해 달성됩니다.

    • 실패 감지: 하드웨어 오류가 발생하면 (예: TPU 작업자 비정상 종료) Pathways 시스템이 이를 감지하고 해당 하드웨어에 있던 데이터에 액세스할 때마다 예외를 통해 사용자의 학습 작업에 알립니다. 이 알림은 워크로드를 비정상 종료하지 않습니다. 코드가 알림을 처리하고 리소스를 재구성하여 처리를 계속하거나 정상적으로 종료할 수 있습니다.
    • 사용자 정의 탄력성 핸들러: 사용자 모델 코드가 이 예외를 처리할 수 있어야 합니다. 이것이 '모델별 복구'를 만드는 요소입니다.
      • 스냅샷: 가장 일반적인 접근 방식은 모델 상태의 스냅샷을 주기적으로 저장하는 것입니다. 오류가 발생하면 최신 스냅샷에서 로드하여 학습을 재개할 수 있습니다.
      • 재구성: 사용 가능한 슬라이스 수에 맞게 학습 작업을 재구성해야 할 수 있습니다. 예를 들어 한 슬라이스가 작동하지 않으면 대체 슬라이스를 사용할 수 있을 때까지 활성 슬라이스 수를 1개 줄일 수 있습니다. 자세한 내용은 Elastic 핸들러를 참고하세요.
      • 데이터/컴퓨팅 그래프 업데이트: 필요에 따라 컴퓨팅 그래프를 다시 만들어 컴퓨팅에 사용할 수 있는 기기 수의 변경사항을 코드에서 처리해야 합니다. 여기에는 데이터를 다시 파티셔닝하거나 모델을 다시 컴파일하는 것이 포함될 수 있습니다.
    • 복구에서 경로의 역할: 경로는 사용자 정의 재구성을 지원하는 기본 요소를 제공합니다.
      • 슬라이스 교체: 실패한 슬라이스가 교체되면 새 슬라이스를 사용할 수 있게 된 후 클라이언트에 알릴 수 있습니다. 그러면 코드가 이 새 슬라이스를 사용하도록 재구성될 수 있습니다.
      • 투명한 복구: Pathways는 클러스터의 정상 부분에 대한 연결을 다시 설정하는 등 복구의 하위 수준 세부정보를 처리합니다.
    • pathwaysutils의 유틸리티: pathways-utils에 정의된 경로 유틸리티 집합입니다.

    탄력적 핸들러 구현

    작성해야 하는 대부분의 코드는 사용자 정의 탄력적 핸들러에 있습니다. 이 핸들러는 메시를 다시 만들고 학습 루프를 다시 초기화하여 탄력적 이벤트 (예: TPU 슬라이스를 사용할 수 없게 됨)에 반응합니다.

    각 워크로드는 고유합니다. 탄력적 핸들러의 복잡성은 워크로드의 복잡성에 따라 확장될 수 있습니다. 핸들러의 입력과 출력은 학습 루프를 다시 초기화하는 데 필요한 최소 인수와 반환 값이어야 합니다.

    def elastic_handler(elastic_utils, *args, **kwargs):
      mesh = initialize_mesh(**kwargs["mesh_kwargs"])
      initial_state, initial_step, jitted_train_step, other_variables =
          initialize_training_loop(mesh, **kwargs["initialize_training_loop_kwargs"])
    
      step, snapshot = elastic_utils.get_next_snapshot()
      state = initial_state.replace(**snapshot)
    
      return state, step, mesh, jitted_train_step, other_variables
    

    학습 루프 업데이트

    학습 루프를 다음과 같이 변경해야 합니다.

    1. 탄력적 관리자 만들기
    2. jax.errors.JaxRuntimeError를 처리하는 try-except 블록 내에 학습 루프를 래핑합니다.
    3. jax.errors.JaxRuntimeError 핸들러 내에서 maybe_reshard_down를 호출합니다. 탄력적 관리자는 오류가 탄력적 이벤트와 관련된 경우 다시 샤딩하고 그렇지 않은 경우 다시 발생시킵니다.
    4. 학습 루프가 끝나면 maybe_snapshotmaybe_reshard_up 호출
    import pathwaysutils
    from pathwaysutils.elastic import manager
    
    pathwaysutils.initialize()
    
    def initialize_mesh(**kwargs):
      ...
    
    
    def initialize_training_loop(**kwargs):
      ...
    
    
    def train_loop(
        final_step,
        elastic_manager,
        mesh_kwargs,
        initialize_training_loop_kwargs,
    ):
      mesh = initialize_mesh(**mesh_kwargs)
      initial_state, initial_step, jitted_train_step, other_variables =
          initialize_training_loop(mesh, **initialize_training_loop_kwargs)
    
      step = initial_step
      while step < final_step:
        try:
          state = jitted_train_step(state)
    
          elastic_manager.maybe_snapshot(step=step, snapshot=state)
          handler_returns = elastic_manager.maybe_reshard_up(
              step=step,
              snapshot=state,
              elastic_handler=elastic_handler,
              handler_args=(),
              handler_kwargs=dict(
                  mesh_kwargs=mesh_kwargs,
                  initialize_training_loop_kwargs=initialize_training_loop_kwargs,
              ),
          )
          if handler_returns:
            state, step, mesh, jitted_train_step, other_variables = handler_returns
          step += 1
        except jax.errors.JaxRuntimeError as error:
          handler_returns = elastic_manager.maybe_reshard_down(
              error=error,
              elastic_handler=elastic_handler,
              handler_args=(),
              handler_kwargs=dict(
                  mesh_kwargs=mesh_kwargs,
                  initialize_training_loop_kwargs=initialize_training_loop_kwargs,
              ),
          )
          if handler_returns:
            state, step, mesh, jitted_train_step, other_variables = handler_returns
    
      return state
    
    
    def main():
      elastic_manager = manager.Manager(
          devices=jax.devices(),
          snapshot_period=10,
          snapshot_buffer_size=1,
          reshard_check_period=5,
          max_elastic_down_event_count=10,
          max_reshard_retry_count=3,
      )
    
      train_loop(100, elastic_manager, {}, {})
    

    탄력적 관리자 구성

    탄력적 관리자는 몇 가지 다른 방식으로 구성할 수 있습니다. 스냅샷 빈도는 스냅샷 기간에 따라 결정됩니다. 스냅샷 기간은 탄력적인 이벤트로 인해 손실된 평균 걸음 수에 영향을 미칩니다. 리샤드 확인 기간은 학습 루프가 슬라이스 사용 가능 여부를 폴링하는 빈도를 결정합니다. max_elastic_down_event_count를 사용하면 학습 루프에서 지원할 슬라이스 손실로 인한 탄력적 이벤트 수를 설정할 수 있습니다. max_reshard_retry_count는 탄력적 관리자가 리샤딩을 재시도해야 하는 횟수를 지정합니다. 관리자는 싱글톤 객체이며 한 번만 생성해야 합니다.

    스냅샷

    탄력적 관리자 구성에 따라 함수는 탄력적 이벤트 중에 탄력적 핸들러가 사용할 수 있는 호스트 메모리에 데이터를 스냅샷할 수 있습니다.

    샤딩 줄이기

    jax.errors.JaxRuntimeError를 포착한 후 Pathways는 손실된 슬라이스로 인한 탄력적 이벤트로 인해 오류가 발생하는지 확인합니다. 그렇다면 성공하거나 최대 재시도 횟수에 도달할 때까지 루프에서 탄력적 핸들러를 호출합니다. 탄력적 이벤트로 인한 오류가 아니면 오류가 다시 발생합니다. 탄력적 핸들러의 반환 값은 호출자에게 전달됩니다.

    샤딩 증가

    탄력적 관리자 구성에 따라 사용할 수 없는 슬라이스가 있는 경우 Pathways는 추가 슬라이스를 사용할 수 있게 되었는지 확인합니다. 이 경우 스냅샷을 즉시 저장하고 (현재 단계의 기존 스냅샷이 아직 촬영되지 않은 경우) 성공하거나 최대 재시도 횟수에 도달할 때까지 루프에서 탄력적 핸들러를 호출합니다. 리샤딩이 발생하면 탄력적 핸들러의 반환 값이 호출자에게 전달됩니다. 그렇지 않으면 None이 반환됩니다.

    핫스왑

    핫 스왑은 우선순위가 높은 작업이 우선순위가 낮은 작업의 리소스를 빠르게 인계받아 다운타임을 최소화하고 더 빠른 복구를 보장하는 GKE JobSet API의 기능을 말합니다.

    JobSet이 생성되면 GKE는 JobSet 구성에 지정된 대로 여러 슬라이스에 워크로드를 예약합니다. 하나 이상의 슬라이드에서 하드웨어 장애가 발생하면 영향을 받는 포드가 실패로 표시됩니다. 이 Jobset을 재예약할 때 우선순위가 낮은 작업에 사용할 수 있는 예비 슬라이스를 GKE 클러스터에 유지하도록 선택한 경우 JobSet 시스템은 우선순위가 높은 작업의 실패한 슬라이스의 워크로드를 동일한 GKE 클러스터 내에서 우선순위가 낮은 작업에서 사용 중인 예비 슬라이스에 다시 매핑합니다. 이 리매핑은 일반적으로 1분 이내에 완료됩니다.

    JobSet이 다시 시작되면 다음과 같은 상황에서 핫스왑이 발생할 수 있습니다.

    1. 기본 모드: 동일한 클러스터 내에서 유휴 예비 TPU 슬라이스를 사용할 수 있는 경우 Kubernetes 스케줄러는 실패한 슬라이스가 복구될 때까지 기다리는 대신 다시 시작된 작업을 이러한 슬라이스에 예약하는 것을 우선시합니다. 이렇게 하면 더 빠르게 복구할 수 있습니다.
    2. 이종 워크로드: 구성된 Kubernetes PriorityClass를 사용하여 여러 워크로드를 실행하는 클러스터에서 다시 시작된 JobSet이 핫 스왑을 트리거할 수 있습니다. 다시 시작된 작업의 어피니티가 우선순위가 낮은 작업의 리소스와 일치하면 Kubernetes는 우선순위가 낮은 작업을 선점하여 우선순위가 높은 작업이 즉시 시작되도록 합니다. 예를 들어 PriorityClass를 사용하여 다양한 우선순위로 Pathways 작업자 포드를 구성할 수 있습니다.

    클러스터에서 우선순위를 사용하려면 다음과 같이 우선순위 클래스를 정의합니다.

    kind: PriorityClass
    metadata:
      name: high-prior-job
    value: 2000
    globalDefault: false
    description: "This priority class should be used for high priority job."
    

    이 YAML을 GKE 클러스터에 적용합니다.

    kubectl apply -f high-prior-job.yaml
    

    다음으로 pathways-worker 포드의 podspec에 다음 텍스트를 추가하여 새로운 우선순위 클래스를 Pathways 작업자 작업에 연결합니다.

    priorityClassName: high-prior-job
    

    다음 단계