경로를 사용하는 JAX의 분산된 특성으로 인해 통신 오버헤드로 인해 일부 작업이 제대로 확장되지 않을 수 있습니다. 경로에서는 비동기 디스패치와 같은 기능으로 이러한 오버헤드를 최소화하지만 JAX 워크로드를 경로로 포팅하거나 경로 워크로드로 JAX를 대규모 액셀러레이터로 확장할 때는 몇 가지 사항을 알고 있어야 합니다.
시작하기 전에
다음 사항이 필요합니다.
프로세스 색인
Pathways가 적용된 JAX는 Pathways 클러스터의 모든 기기를 로컬로 취급합니다. 이렇게 하면 기기 관리가 간소화되고 JAX가 사용 가능한 모든 리소스를 활용할 수 있습니다. 실제로 이는 다음을 의미합니다.
jax.process_index()는 모든 기기에서 항상 0입니다.jax.devices()및jax.local_devices()는 전체 작업에서 모든 TPU 기기를 반환합니다.
하드웨어 유형 및 코로케이션
최상의 성능을 위해 모든 Pathways 구성요소와 사용자 작업을 동일한 Google Cloud 클라우드 영역에 배치하세요. IFRT 프록시 및 리소스 관리자와 같은 대형 CPU를 사용합니다. 64개의 vCPU와 256GB 메모리가 제공되는 전용 n2-standard-64를 사용하는 것이 좋습니다.
PathwaysUtils
Pathways-utils는 클라우드 아키텍처의 Pathways에서 JAX 워크로드의 배포 및 실행을 간소화할 수 있는 필수 유틸리티와 도구를 제공하는 Python 기반 GitHub 저장소입니다. 이 패키지는 클라우드 환경에 필요한 적응을 처리하므로 JAX 개발자는 최소한의 플랫폼별 구성으로 핵심 머신러닝 워크플로에 집중할 수 있습니다. 특히 다음과 같은 기능을 제공합니다.
- '프록시' JAX 백엔드: 이 맞춤 백엔드를 사용하면
JAX_PLATFORMS=proxy환경 변수를 설정하여 JAX 애플리케이션이 Pathways 인프라를 사용할 수 있습니다. - 통합 프로파일링 유틸리티: 애플리케이션의 성능을 파악할 수 있는 프로파일링 기능입니다.
jax.profiler.start_trace및jax.profiler.start_server과 같은 표준 JAX 프로파일링 API를 사용하면 JAX 코드뿐만 아니라 기본 Pathways 구성요소도 프로파일링하여 클라우드 환경 내 실행을 전체적으로 파악할 수 있습니다. - Orbax를 사용한 분산 체크포인트: Pathways 환경 내에서 Orbax 라이브러리를 사용할 때 분산 체크포인트를 사용하고 체크포인트를 복원할 수 있는 맞춤 Orbax 체크포인트 핸들러입니다. 이 통합은
pathwaysutils를 가져오는 한 기존 Orbax 체크포인트 코드의 변경 없이 작동하는 것을 목표로 합니다. - 탄력적 학습 기본 요소: Pathways를 사용하여 강력하고 확장 가능한 학습 워크플로를 빌드하는 데 사용할 수 있는 기본 탄력적 학습 기본 요소를 제공합니다. 이러한 기본 요소를 사용하면 학습 작업이 사용 가능한 리소스의 변화에 동적으로 적응하여 클라우드 환경에서 효율성과 복원력을 개선할 수 있습니다.
체크포인트
Orbax는 Cloud Storage를 사용한 분산 체크포인트 및 복원을 위해 Pathways로 철저히 테스트되었습니다. train.py에서 import pathwaysutils; pathwaysutils.initialize()를 호출하면 IFRT 프록시를 통해 체크포인트 작업을 효율적으로 처리하는 맞춤 ArrayHandler가 등록되어 가속기의 Pathways 작업자가 데이터를 직접 저장하고 복원할 수 있습니다.
동일한 위치에 있는 Python
Colocated Python은 사용자가 지정한 Python 코드를 TPU 또는 GPU 호스트에서 직접 실행할 수 있는 오픈소스 JAX API로, 멀티 컨트롤러 JAX에서 더 간단합니다. 이를 통해 데이터 로드 및 체크포인트와 같은 컴퓨팅 집약적인 작업이 클라이언트와 TPU 머신 간의 데이터 전송을 방지할 수 있습니다. 공동 배치된 Python JAX API를 실행하도록 Pathways 클러스터를 구성하려면 공동 배치된 Python README의 안내를 따르세요. 이 안내에서는 Pathways 작업자 옆에 공동 배치된 Python 사이드카를 시작하는 방법을 설명합니다.
데이터 로딩
학습 중에 데이터 세트에서 배치를 반복적으로 로드하여 모델에 피드합니다. 액셀러레이터의 작업량이 고갈되는 것을 방지하려면 호스트 간에 배치를 샤딩하는 효율적인 비동기 데이터 로더가 필요합니다. Pathways로 학습을 실행할 때 데이터 로더는 CPU VM에서 실행되며 (멀티 컨트롤러 설정에 사용되는 TPU VM과 달리) TPU VM에 데이터를 디스패치합니다. 이렇게 하면 데이터를 읽을 때 지연 시간이 길어지지만 CPU 호스트에서 X개의 배치만큼 미리 읽고 읽은 데이터를 TPU에 비동기식으로 디스패치하여 부분적으로 완화됩니다. 이 솔루션은 소규모에서 중규모로 실행할 때 충분합니다.
대규모로 최적의 성능을 얻으려면 공간을 공유하는 Python을 사용하여 가속기에서 직접 데이터 파이프라인을 실행하여 입력 데이터 파이프라인을 공동 배치하는 것이 좋습니다. 이렇게 하면 CPU 병목 현상이 제거되고 데이터 전송을 위해 TPU의 빠른 상호 연결이 활용됩니다.
TFDS 기반 입력 파이프라인을 이전하는 참조 구현은 multihost_dataloading.py의 RemoteIterator 구현에서 확인할 수 있습니다.
이 구현은 공동 배치된 Python JAX API를 사용하여 다중 컨트롤러 JAX와 Pathways 모두에서 분산 방식으로 작동합니다.
Jax 버전 관리
경로 출시 버전은 호환성과 안정성을 보장하기 위해 JAX 버전과 긴밀하게 연결되어 있습니다. 잠재적인 문제를 방지하려면 Pathways 아티팩트와 JAX 버전이 일치하는지 확인하세요. 각 Pathways 출시 버전은 jax-<version> 형식의 태그를 통해 호환되는 JAX 버전을 명시적으로 지정합니다.
컴파일 캐시
Pathways 지속적 컴파일 캐시는 Pathways 서버가 중복 컴파일을 방지하기 위해 컴파일된 XLA 실행 파일을 Cloud Storage와 같은 지속적 위치에 저장할 수 있는 기능입니다. 이 기능은 기본적으로 사용 설정되어 있습니다. 캐시의 위치는 리소스 관리자 및 Pathways 작업자 컨테이너에 --gcs_scratch_location 플래그로 전달됩니다. 연결된 스토리지 비용을 최소화하기 위해 캐시는 수명 주기 정책을 Cloud Storage 위치에 연결합니다. Cloud Storage 버킷당 정책은 50개로 제한됩니다. 따라서 모든 워크로드에서 공통 Cloud Storage 위치를 사용하는 것이 좋습니다.
이 캐시는 Pathways 워크로드에서 pathwaysutils.initialize()에 의해 사용 중지되는 JAX 컴파일 캐시와 유사합니다.
프로파일링
JAX 프로파일러를 사용하여 JAX 프로그램의 트레이스를 생성할 수 있습니다. Pathways에서 지원되는 일반적인 방법에는 두 가지가 있습니다.
두 경우 모두 프로필이 Cloud Storage 버킷에 작성됩니다. Cloud Storage 버킷에 여러 트레이스 파일이 생성되며, 이는 서로 다른 타임스탬프 폴더에 있을 수 있습니다. 예를 들면 다음과 같습니다.
- 트레이스를 호출한 기본 Python 프로세스 (일반적으로 노트북 VM):
<jax-client-vm-name>.xplane.pb - 경로 IFRT 프록시:
client.<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb - 경로 리소스 관리자:
server.*<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb - 경로 작업자:
server.*<tpu-node-name>.xplane.pb
이 추적 파일은 다음 명령어를 실행하여 TensorBoard로 분석할 수 있습니다. TensorBoard 및 모든 프로파일링 도구에 관한 자세한 내용은 프로파일러를 사용하여 TensorFlow 성능 최적화를 참고하세요.
# verify trace files are present gsutil ls -l -r gs://BUCKET/PREFIX # View on tensorboard tensorboard --logdir=gs://BUCKET/PREFIX
다음을 바꿉니다.
BUCKET: 트레이스 파일을 저장할 Cloud Storage 버킷PREFIX: 추적 파일을 저장할 Cloud Storage 버킷 내 경로
프로그래매틱 프로필 캡처
코드 내에서 프로필을 캡처합니다. 프로필은 타임스탬프 디렉터리 아래의 gs://<bucket>/<prefix> 내부에 저장됩니다.
import jax import pathwaysutils pathwaysutils.initialize() jax.profiler.start_trace("gs://BUCKET/PREFIX") # Run the operations to be profiled key = jax.random.key(0) x = jax.random.normal(key, (5000, 5000)) y = x @ x y.block_until_ready() jax.profiler.stop_trace()
수동 프로필 캡처
프로필 정보를 수동으로 캡처하려면 Python 코드에서 프로파일러 서버를 시작해야 합니다.
import jax
import pathwaysutils
pathwaysutils.initialize()
jax.profiler.start_server(jax_profiler_port)
# Your JAX code
jax.profiler.stop_server() # this is functinoally a no-op
프로파일러 서버가 실행되는 동안 프로필을 캡처하고 데이터를 타겟 Cloud Storage 위치로 내보낼 수 있습니다.
export DURATION_IN_SECS=6
export OUTPUT_DIR=gs://<var>BUCKET</var>/<var>PREFIX</var>
curl -d "{\"duration_ms\":\"${DURATION_IN_SECS} * 1000 }}\", \"repository_path\":\"${OUTPUT_DIR}\"}" -H "Content-Type: application/json" -X POST http://localhost:<jax_profiler_port
프로그램의 트레이스에서 Compile 및 Execute과 같은 IFRT 프록시 클라이언트 메서드의 타이밍 정보를 확인할 수 있습니다. 컴파일 및 실행 중에 IFRT 프록시 gRPC 서버와의 상호작용을 자세히 설명하는 이러한 이벤트는 GrpcClientSessionUserFuturesWorkQueue라는 스레드에 표시됩니다. 트레이스에서 이 스레드를 검사하면 이러한 작업의 성능에 관한 유용한 정보를 얻을 수 있습니다.
XLA 플래그
경로를 사용하는 경우 경로 프록시 컨테이너에서 XLA 플래그를 설정해야 합니다. XPK 또는 PathwaysJob API를 사용하여 이 작업을 수행할 수 있습니다.
XPK를 사용하는 경우 다음과 같이 XLA 플래그를 설정합니다.
--custom-pathways-proxy-server-args="--xla_flag_1=value1 --xla_flag_2=value2"
PathwaysJob API를 사용할 때는 다음과 같이 XLA 플래그를 설정하세요.
apiVersion: pathways-job.pathways.domain/v1 kind: PathwaysJob metadata: name: pathways-USER spec: customComponents: - componentType: proxy_server customFlags: - --xla_flag_1=value1 - --xla_flag_2=value2
다음을 바꿉니다.
USER: Google Cloud 사용자 이름value[n]: 설정하려는 XLA 플래그
HLO 덤프
XLA 컴파일러에 제공되는 고수준 옵티마이저 (HLO) 입력을 자세히 살펴보려면 다음과 같이 지정된 Cloud Storage 위치에 HLO를 덤프하도록 Pathways를 구성하면 됩니다.
apiVersion: pathways-job.pathways.domain/v1 kind: PathwaysJob metadata: name: pathways-USER spec: customComponents: - componentType: proxy_server customEnv: - name: XLA_FLAGS value: "--xla_dump_to=gs://your-gcs-bucket/your-desired-prefix/"
다음 단계
- 경로를 사용하여 GKE 클러스터 만들기
- Pathways를 사용한 멀티 호스트 추론
- 경로가 있는 일괄 워크로드
- 학습 과정 대화형 모드
- 학습 과정으로 복원력 있는 학습
- 문제 해결 경로