TPU 슬라이스에서 JAX 코드 실행

이 문서의 명령어를 실행하기 전 계정 및 Cloud TPU 프로젝트 설정의 안내를 따라야 합니다.

단일 TPU 보드에서 JAX 코드를 실행한 후에는 TPU 슬라이스에서 실행하여 코드를 수직 확장할 수 있습니다. TPU 슬라이스는 전용 고속 네트워크 연결을 통해 서로 연결된 여러 TPU 보드입니다. 이 문서는 TPU 슬라이스에서 JAX 코드 실행에 대한 안내입니다. 더 자세한 내용은 멀티 호스트 및 멀티 프로세스 환경에서 JAX 사용을 참조하세요.

필수 역할

TPU를 만들고 SSH를 사용하여 연결하는 데 필요한 권한을 얻으려면 관리자에게 프로젝트에 대한 다음 IAM 역할을 부여해 달라고 요청하세요.

역할 부여 방법에 대한 자세한 내용은 프로젝트, 폴더, 조직에 대한 액세스 관리를 참조하세요.

커스텀 역할 또는 다른 사전 정의된 역할을 통해 필요한 권한을 얻을 수도 있습니다.

Cloud TPU 슬라이스 만들기

  1. 몇 가지 환경 변수를 만듭니다.

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=europe-west4-b
    export ACCELERATOR_TYPE=v5litepod-32
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite

    환경 변수 설명

    변수 설명
    PROJECT_ID Google Cloud 프로젝트 ID입니다. 기존 프로젝트를 사용하거나 새 프로젝트를 만듭니다.
    TPU_NAME TPU 이름입니다.
    ZONE TPU VM을 만들 영역입니다. 지원되는 영역에 대한 자세한 내용은 TPU 리전 및 영역을 참조하세요.
    ACCELERATOR_TYPE 액셀러레이터 유형은 만들려는 Cloud TPU의 버전과 크기를 지정합니다. 각 TPU 버전에서 지원되는 액셀러레이터 유형에 대한 자세한 내용은 TPU 버전을 참조하세요.
    RUNTIME_VERSION Cloud TPU 소프트웨어 버전입니다.

  2. gcloud 명령어를 사용하여 TPU 슬라이스를 만듭니다. 예를 들어 v5litepod-32 슬라이스를 만들려면 다음 명령어를 사용합니다.

    $ gcloud compute tpus tpu-vm create ${TPU_NAME}  \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        --accelerator-type=${ACCELERATOR_TYPE}  \
        --version=${RUNTIME_VERSION} 

슬라이스에 JAX 설치

TPU 슬라이스를 만든 후 TPU 슬라이스의 모든 호스트에 JAX를 설치해야 합니다. --worker=all--commamnd 파라미터를 사용하여 gcloud compute tpus tpu-vm ssh 명령어로 이 작업을 수행할 수 있습니다.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --worker=all \
  --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

슬라이스에서 JAX 코드 실행

TPU 슬라이스에서 JAX 코드를 실행하려면 TPU 슬라이스의 각 호스트에서 코드를 실행해야 합니다. jax.device_count() 호출은 슬라이스의 각 호스트에서 호출될 때까지 응답을 중지합니다. 다음 예시에서는 TPU 슬라이스에서 JAX 계산을 실행하는 방법을 보여줍니다.

코드 준비

gcloud 버전 344.0.0 이상이 필요합니다(scp 명령어의 경우). gcloud --version을 사용하여 gcloud 버전을 확인하고 필요하면 gcloud components upgrade를 실행합니다.

다음 코드를 사용하여 example.py 파일을 만듭니다.


import jax

# Initialize the slice
jax.distributed.initialize()

# The total number of TPU cores in the slice
device_count = jax.device_count()

# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()

# The psum is performed over all mapped devices across the slice
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
    print('global device count:', jax.device_count())
    print('local device count:', jax.local_device_count())
    print('pmap result:', r)

슬라이스의 모든 TPU 워커 VM에 example.py 복사

$ gcloud compute tpus tpu-vm scp ./example.py ${TPU_NAME}: \
  --worker=all \
  --zone=${ZONE} \
  --project=${PROJECT_ID}

이전에 scp 명령어를 사용하지 않았으면 다음과 비슷한 오류가 표시될 수 있습니다.

ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH
agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try
again.

오류를 해결하려면 오류 메시지에 표시된 대로 ssh-add 명령어를 실행하고 명령어를 다시 실행합니다.

슬라이스에서 코드 실행

모든 VM에서 example.py 프로그램을 실행합니다.

$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --worker=all \
  --command="python3 ./example.py"

출력(v5litepod-32 슬라이스에서 생성됨):

global device count: 32
local device count: 4
pmap result: [32. 32. 32. 32.]

삭제

TPU VM 사용이 완료되었으면 다음 단계를 수행하여 리소스를 삭제합니다.

  1. Cloud TPU 및 Compute Engine 리소스를 삭제합니다.

    $ gcloud compute tpus tpu-vm delete ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID}
  2. gcloud compute tpus execution-groups list를 실행하여 리소스가 삭제되었는지 확인합니다. 삭제되는 데 몇 분 정도 걸릴 수 있습니다. 다음 명령어 출력에는 이 튜토리얼에서 생성된 리소스가 포함되어서는 안 됩니다.

    $ gcloud compute tpus tpu-vm list --zone=${ZONE} \
    --project=${PROJECT_ID}