JAX 문제 해결 - TPU

이 가이드에서는 Cloud TPU에서 JAX 모델을 학습시키는 동안 발생할 수 있는 문제를 식별하고 해결하는 데 도움이 되는 JAX 문제 해결 정보를 제공합니다.

Cloud TPU 시작에 대한 보다 일반적인 가이드는 JAX 빠른 시작을 참조하세요.

일반적인 JAX 문제

학습 모델을 개발하거나 JAX로 학습할 때 문제가 발생하면 JAX FAQ를 참조하세요.

JAX로 학습 애플리케이션을 작성할 때 발생할 수 있는 보다 일반적인 프로그래밍 오류는 JAX 오류를 참조하세요.

JAX 성능 프로파일링

JAX 성능 프로파일링에 설명된 도구를 사용하여 TPU 리소스가 사용되는 방식을 파악할 수 있습니다.

메모리 문제 해결

JAX 기기 메모리 프로파일러를 사용하여 메모리가 사용되는 방식을 모니터링할 수 있지만 메모리 사용 방식을 직접 관리할 수는 없습니다.

JAX 기기 메모리 프로파일러를 사용하여 다음을 수행할 수 있습니다.

특정 작업에 TPU 메모리가 할당되는 방식을 지정할 수 없습니다. JAX 특정 TPU 성능 문제에 대한 자세한 내용은 JAX에서 TPU 사용 시 성능 참고사항을 참조하세요.

TPU 문제 해결

다음 섹션에서는 TPU에서 JAX 프로그램을 실행할 때 발생할 수 있는 몇 가지 일반적인 문제를 해결하는 방법을 설명합니다.

TPU가 실행 중인지 확인하려면 어떻게 해야 하나요?

JAX가 'GPU/TPU를 찾을 수 없으며 CPU로 돌아갑니다.'를 출력하지 않는 한 모든 항목이 TPU에서 실행됩니다.

여러 TPU 기기가 표시되는 경우 jax.devices()를 확인하여 TPU가 활성 상태인지 확인하거나, assert jax.devices()[0].platform == 'tpu'를 사용하여 프로그래매틱 방식으로 확인할 수 있습니다.

RuntimeError: 백엔드 'tpu'를 초기화할 수 없음: UNAVAILABLE: 사용 가능한 TPU 플랫폼이 없습니다.

이 런타임 오류 메시지 또는 TPU VM의 /tmp/tpu_logs/tpu_driver.WARNING에서 다음 로그를 확인한 경우: W1118 17:40:20.985243 23901 tpu_version_flag.cc:57] No hardware is found. Using default TPU version:xxxxxx 이는 잘못된 TPU VM 버전에서 실행 중일 가능성을 나타냅니다.

현재 JAX 런타임 버전을 실행 중인지 확인한 후 다시 시도하세요.

TPU 및 GKE 문제 해결

문제 해결을 위해 GKE 워크로드 매니페스트에서 상세 로깅을 사용 설정한 다음 GKE 지원팀에 로그를 제공하세요.

TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0

다음 섹션에서는 TPU 및 GKE 설정과 관련된 오류 메시지와 해결 방법을 설명합니다.

서비스 'jobset-webhook-service'에 사용할 수 있는 엔드포인트 없음

이 오류는 JobSet가 올바르게 설치되지 않았음을 의미합니다. jobset-controller-manager 배포 Kubernetes 포드가 실행 중인지 확인합니다. 자세한 내용은 JobSet 문제 해결 문서를 참고하세요.

TPU 초기화 실패: 연결 실패

GKE 노드 버전이 1.30.4-gke.1348000 이상인지 확인합니다(GKE 1.31은 지원되지 않음).