JAX 워크로드를 Pathways로 포팅

Pathways를 사용하는 JAX는 분산형이므로 통신 오버헤드로 인해 일부 작업의 확장성이 떨어질 수 있습니다. Pathways는 비동기 디스패치와 같은 기능으로 이러한 오버헤드를 최소화하지만 JAX 워크로드를 Pathways로 포팅하거나 Pathways를 사용하는 JAX 워크로드를 많은 수의 가속기로 확장할 때는 몇 가지 사항에 유의해야 합니다.

시작하기 전에

다음 사항이 필요합니다.

프로세스 색인

Pathways를 사용하는 JAX는 Pathways 클러스터의 모든 기기를 로컬로 취급합니다. 이렇게 하면 기기 관리가 간소화되고 JAX에서 사용 가능한 모든 리소스를 활용할 수 있습니다. 실제로 이는 다음을 의미합니다.

  • jax.process_index()는 모든 기기에서 항상 0입니다.
  • jax.devices()jax.local_devices()는 전체 작업에서 모든 TPU 기기를 반환합니다.

하드웨어 유형 및 공동 배치

최상의 성능을 위해 모든 Pathways 구성요소와 사용자 작업을 동일한 Google Cloud 클라우드 영역에 배치합니다. IFRT 프록시 및 리소스 관리자와 같은 대형 CPU를 사용합니다. 64개의 vCPU와 256GB 메모리가 제공되는 전용 n2-standard-64 이상을 사용하는 것이 좋습니다.

PathwaysUtils

Pathways-utils는 Cloud의 Pathways 아키텍처에서 JAX 워크로드의 배포 및 실행을 간소화할 수 있는 필수 유틸리티와 도구를 제공하는 Python 기반 GitHub 저장소입니다. 이 패키지는 클라우드 환경에 필요한 적응을 처리하므로 JAX 개발자는 최소한의 플랫폼별 구성으로 핵심 머신러닝 워크플로에 집중할 수 있습니다. 특히 다음을 제공합니다.

  • '프록시' JAX 백엔드: 이 커스텀 백엔드를 사용하면 JAX_PLATFORMS=proxy 환경 변수를 설정하여 JAX 애플리케이션에서 Pathways 인프라를 사용할 수 있습니다.
  • 통합 프로파일링 유틸리티: 애플리케이션의 성능을 파악할 수 있는 프로파일링 기능입니다. jax.profiler.start_tracejax.profiler.start_server와 같은 표준 JAX 프로파일링 API를 사용하면 JAX 코드뿐만 아니라 기본 Pathways 구성요소도 프로파일링하여 클라우드 환경 내에서 실행을 전체적으로 파악할 수 있습니다.
  • Orbax를 사용한 분산 체크포인트: Pathways 환경 내에서 Orbax 라이브러리를 사용할 때 분산 체크포인트를 사용하고 체크포인트를 복원할 수 있는 커스텀 Orbax 체크포인트 핸들러입니다. 이 통합은 pathwaysutils를 가져오는 한 기존 Orbax 체크포인트 코드를 변경하지 않고 작동하는 것을 목표로 합니다.
  • 탄력적 학습 기본 요소: Pathways를 사용하여 강력하고 확장 가능한 학습 워크플로를 빌드하는 데 사용할 수 있는 기본적인 탄력적 학습 기본 요소를 제공합니다. 이러한 기본 요소를 사용하면 학습 작업이 사용 가능한 리소스의 변경사항에 동적으로 적응하여 클라우드 환경에서 효율성과 복원력을 개선할 수 있습니다.

체크포인트

Orbax는 Cloud Storage를 사용한 분산 체크포인트 및 복원을 위해 Pathways로 철저히 테스트됩니다. `train.py`에서 `import pathwaysutils; pathwaysutils.initialize()`를 호출하면 ArrayHandler를 효율적으로 처리하는 커스텀 IFRT 프록시를 통해 체크포인트 작업을 효율적으로 처리하는 커스텀 ArrayHandler가 등록되어 가속기의 Pathways 작업자가 데이터를 직접 저장하고 복원할 수 있습니다.

공동 배치된 Python

공동 배치된 Python 은 사용자 지정 Python 코드를 TPU 또는 GPU 호스트에서 직접 실행할 수 있는 오픈소스 JAX API로, 다중 컨트롤러 JAX에서 더 간단합니다. 이렇게 하면 데이터 로드 및 체크포인트와 같은 컴퓨팅 집약적인 작업에서 클라이언트와 TPU 머신 간의 데이터 전송을 방지할 수 있습니다. 공동 배치된 Python JAX API를 실행하도록 Pathways 클러스터를 구성하려면 공동 배치된 Python README의 안내를 따르세요. 이 안내에서는 Pathways 작업자와 함께 공동 배치된 Python 사이드카를 시작하는 방법을 설명합니다.

데이터 로딩

학습 중에 데이터 세트에서 배치를 반복적으로 로드하여 모델에 피드합니다. 가속기의 작업량이 고갈되는 것을 방지하려면 호스트 간에 배치를 샤딩하는 효율적인 비동기 데이터 로더가 필요합니다. Pathways로 학습을 실행할 때 데이터 로더는 CPU VM에서 실행되며 (다중 컨트롤러 설정에서 사용되는 TPU VM과 달리) TPU VM에 데이터를 디스패치합니다. 이렇게 하면 데이터 읽기 지연 시간이 길어지지만 CPU 호스트에서 X개의 배치를 미리 읽고 읽은 데이터를 TPU에 비동기식으로 디스패치하여 부분적으로 완화됩니다. 이 솔루션은 소규모에서 중간 규모로 실행할 때 충분합니다.

대규모로 최적의 성능을 얻으려면 공동 배치된 Python을 사용하여 가속기에서 직접 데이터 파이프라인을 실행하여 입력 데이터 파이프라인을 공동 배치하는 것이 좋습니다. 이렇게 하면 CPU 병목 현상이 제거되고 데이터 전송을 위해 TPU의 빠른 상호 연결이 활용됩니다.

TFDS 기반 입력 파이프라인을 이전하는 참조 구현은 RemoteIterator 구현에서 multihost_dataloading.py 확인할 수 있습니다. 이 구현은 공동 배치된 Python JAX API를 사용하여 분산 방식으로 다중 컨트롤러 JAX와 Pathways 모두에서 작동합니다.

Jax 버전 관리

Pathways 출시 버전은 호환성과 안정성을 보장하기 위해 JAX 버전과 긴밀하게 연결되어 있습니다. 잠재적인 문제를 방지하려면 Pathways 아티팩트와 JAX 버전이 일치하는지 확인하세요. 각 Pathways 출시 버전은 호환되는 JAX 버전을 jax-<version>형식의 태그를 통해 명확하게 지정합니다.

컴파일 캐시

Pathways 영구 컴파일 캐시는 Pathways 서버가 중복 컴파일을 방지하기 위해 컴파일된 XLA 실행 파일을 Cloud Storage와 같은 영구 위치에 저장할 수 있는 기능입니다. 이 기능은 기본적으로 사용하도록 설정되어 있습니다. 캐시의 위치는 리소스 관리자 및 Pathways 작업자 컨테이너에 --gcs_scratch_location 플래그로 전달됩니다. 관련 스토리지 비용을 최소한으로 유지하기 위해 캐시는 수명 주기 정책을 Cloud Storage 위치에 연결합니다. Cloud Storage 버킷당 정책은 50개로 제한됩니다. 따라서 모든 워크로드에서 공통 Cloud Storage 위치를 사용하는 것이 좋습니다.

이 캐시는 Pathways 워크로드의 경우 JAX 컴파일 캐시 pathwaysutils.initialize()에 의해 사용 중지되는 것과 유사합니다.

컴파일 캐시에는 다음 Cloud Storage 권한이 필요합니다.

  • storage.buckets.get: 버킷 메타데이터를 검색합니다.
  • storage.buckets.update: Pathways에서 객체 수명 주기 정책을 설정하여 캐시 삭제를 위한 TTL을 적용하는 데 필수적입니다.
  • storage.objects.list: 버킷 내에 있는 기존 캐시 객체를 나열합니다.
  • storage.objects.create: 컴파일된 새 실행 파일을 캐시에 씁니다.
  • storage.objects.get: 버킷에서 캐시된 실행 파일을 읽습니다.

프로파일링

JAX 프로파일러를 사용하여 JAX 프로그램의 트레이스를 생성할 수 있습니다. Pathways에서 지원되는 일반적인 방법은 두 가지입니다.

  • 프로그래매틱
    • JAX 코드에서 프로그래매틱 방식으로 프로필 캡처
  • 수동
    • JAX 코드에서 프로파일러 서버를 시작한 후 온디맨드 프로필 캡처

두 경우 모두 프로필이 Cloud Storage 버킷에 작성됩니다. Cloud Storage 버킷에는 잠재적으로 다른 타임스탬프 폴더 아래에 여러 트레이스 파일이 생성됩니다. 예를 들면 다음과 같습니다.

  • 트레이스를 호출한 기본 Python 프로세스 (일반적으로 노트북 VM): <jax-client-vm-name>.xplane.pb
  • Pathways IFRT 프록시: client.<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb
  • Pathways 리소스 관리자: server.*<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb
  • Pathways 작업자: server.*<tpu-node-name>.xplane.pb

이러한 트레이스 파일은 다음 명령어를 실행하여 TensorBoard로 분석할 수 있습니다. TensorBoard 및 모든 프로파일링 도구에 관한 자세한 내용은 Profiler를 사용하여 TensorFlow 성능 최적화를 참조하세요.

# verify trace files are present
gsutil ls -l -r gs://BUCKET/PREFIX

# View on tensorboard
tensorboard --logdir=gs://BUCKET/PREFIX

다음을 바꿉니다.

  • BUCKET : 트레이스 파일을 저장할 Cloud Storage 버킷
  • PREFIX: 트레이스 파일을 저장할 Cloud Storage 버킷 내의 경로

프로그래매틱 프로필 캡처

코드 내부에서 프로필을 캡처합니다. 프로필은 타임스탬프 디렉터리 아래 gs://<bucket>/<prefix> 내에 저장됩니다.

import jax
import pathwaysutils

pathwaysutils.initialize()

jax.profiler.start_trace("gs://BUCKET/PREFIX")

# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()

jax.profiler.stop_trace()

수동 프로필 캡처

프로필 정보를 수동으로 캡처하려면 Python 코드에서 프로파일러 서버를 시작해야 합니다.

import jax
import pathwaysutils

pathwaysutils.initialize()
jax.profiler.start_server(jax_profiler_port)

# Your JAX code
jax.profiler.stop_server() # this is functionally a no-op

프로파일러 서버가 실행되는 동안 프로필을 캡처하고 데이터를 대상 Cloud Storage 위치로 내보낼 수 있습니다.

export DURATION_IN_SECS=6
export OUTPUT_DIR=gs://<var>BUCKET</var>/<var>PREFIX</var>

curl -d "{\"duration_ms\":\"${DURATION_IN_SECS} * 1000 }}\", \"repository_path\":\"${OUTPUT_DIR}\"}" -H "Content-Type: application/json" -X POST http://localhost:<jax_profiler_port

프로그램의 트레이스 내에서 CompileExecute와 같은 IFRT 프록시 클라이언트 메서드의 타이밍 정보를 확인할 수 있습니다. 컴파일 및 실행 중에 IFRT 프록시 gRPC 서버와의 상호작용을 자세히 설명하는 이러한 이벤트는 GrpcClientSessionUserFuturesWorkQueue라는 스레드에 표시됩니다. 트레이스에서 이 스레드를 검사하면 이러한 작업의 성능에 관한 통계를 얻을 수 있습니다.

XLA 플래그

Pathways를 사용하는 경우 pathways-proxy 컨테이너에서 XLA 플래그를 설정해야 합니다. XPK 또는 PathwaysJob API를 사용하여 이 작업을 실행할 수 있습니다.

XPK를 사용하는 경우 다음과 같이 XLA 플래그를 설정합니다.

--custom-pathways-proxy-server-args="--xla_flag_1=value1 --xla_flag_2=value2"

PathwaysJob API를 사용하는 경우 다음과 같이 XLA 플래그를 설정합니다.

apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USER
spec:
  customComponents:
  - componentType: proxy_server
    customFlags:
    - --xla_flag_1=value1
    - --xla_flag_2=value2

다음을 바꿉니다.

  • USER : 사용자 Google Cloud 이름
  • value[n]: 설정하려는 XLA 플래그

HLO 덤프

XLA 컴파일러에 제공되는 HLO (High Level Optimizer) 입력을 자세히 살펴보려면 다음과 같이 지정된 Cloud Storage 위치에 HLO를 덤프하도록 Pathways를 구성하면 됩니다.

apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USER
spec:
  customComponents:
  - componentType: proxy_server
    customEnv:
    - name: XLA_FLAGS
      value: "--xla_dump_to=gs://your-bucket/your-desired-prefix/"

다음 단계