Vertex AI는 TPU VM을 사용하여 다양한 프레임워크 및 라이브러리로 학습을 지원합니다. 컴퓨팅 리소스를 구성할 때 TPU v2, TPU v3, TPU v5e VM을 지정할 수 있습니다. TPU v5e는 JAX 0.4.6 이상, TensorFlow 2.15 이상, PyTorch 2.1 이상을 지원합니다. TPU v6e는 PJRT를 기본 런타임으로 사용하는 Python 3.10 이상, JAX 0.4.37 이상, PyTorch 2.1 이상을 지원합니다.
커스텀 학습을 위한 TPU VM 구성에 관한 자세한 내용은 커스텀 학습을 위한 컴퓨팅 리소스 구성을 참고하세요.
TensorFlow 학습
사전 빌드된 컨테이너
TPU를 지원하는 사전 빌드된 학습 컨테이너를 사용하고 Python 학습 애플리케이션을 만듭니다.
커스텀 컨테이너
TPU VM용으로 특별히 빌드된 tensorflow 및 libtpu 버전이 설치된 커스텀 컨테이너를 사용합니다. 이러한 라이브러리는 Cloud TPU 서비스에서 유지 관리되고 지원되는 TPU 구성 문서에 나열되어 있습니다.
원하는 tensorflow 버전과 해당 libtpu 라이브러리를 선택합니다. 그런 후 컨테이너를 빌드할 때 Docker 컨테이너 이미지에 이를 설치합니다.
예를 들어 TensorFlow 2.15를 사용하려면 Dockerfile에 다음 안내를 포함합니다.
  # Download and install `tensorflow`.
  RUN pip install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/tensorflow/tf-2.15.0/tensorflow-2.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
  # Download and install `libtpu`.
  # You must save `libtpu.so` in the '/lib' directory of the container image.
  RUN curl -L https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/1.9.0/libtpu.so -o /lib/libtpu.so
  # TensorFlow training on TPU v5e requires the PJRT runtime. To enable the PJRT
  # runtime, configure the following environment variables in your Dockerfile.
  # For details, see https://cloud.google.com/tpu/docs/runtimes#tf-pjrt-support.
  # ENV NEXT_PLUGGABLE_DEVICE_USE_C_API=true
  # ENV TF_PLUGGABLE_DEVICE_LIBRARY_PATH=/lib/libtpu.so
TPU Pod
TPU Pod에서 tensorflow를 학습시키려면 학습 컨테이너에 추가 설정이 필요합니다. Vertex AI는 초기 설정을 처리하는 기본 Docker 이미지를 유지합니다.
| 이미지 URI | Python 버전 | 
|---|---|
| 
 | Python 3.8 | 
| 
 | Python 3.10 | 
커스텀 컨테이너를 빌드하는 단계는 다음과 같습니다.
- 선택한 Python 버전의 기본 이미지를 선택합니다. TensorFlow 2.12 이하를 위한 TPU TensorFlow 휠은 Python 3.8을 지원합니다. TensorFlow 2.13 이상은 Python 3.10 이상을 지원합니다. 특정 TensorFlow 휠의 경우에는 Cloud TPU 구성을 참조하세요.
- 트레이너 코드 및 시작 명령어로 이미지를 확장합니다.
# Specifies base image and tag
FROM us-docker.pkg.dev/vertex-ai/training/tf-tpu-pod-base-cp310:latest
WORKDIR /root
# Download and install `tensorflow`.
RUN pip install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/tensorflow/tf-2.15.0/tensorflow-2.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
# Download and install `libtpu`.
# You must save `libtpu.so` in the '/lib' directory of the container image.
RUN curl -L https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/1.9.0/libtpu.so -o /lib/libtpu.so
# Copies the trainer code to the docker image.
COPY your-path-to/model.py /root/model.py
COPY your-path-to/trainer.py /root/trainer.py
# The base image is setup so that it runs the CMD that you provide.
# You can provide CMD inside the Dockerfile like as follows.
# Use CMD, not ENTRYPOINT, to avoid accidentally overriding the pod base image's ENTRYPOINT.
# Alternatively, you can pass it as an `args` value in ContainerSpec:
# (https://cloud.google.com/vertex-ai/docs/reference/rest/v1/CustomJobSpec#containerspec)
CMD ["python3", "trainer.py"]
PyTorch 학습
TPU로 학습할 때 PyTorch에 사전 빌드된 컨테이너 또는 커스텀 컨테이너를 사용할 수 있습니다.
사전 빌드된 컨테이너
TPU를 지원하는 사전 빌드된 학습 컨테이너를 사용하고 Python 학습 애플리케이션을 만듭니다.
커스텀 컨테이너
PyTorch 라이브러리가 설치된 커스텀 컨테이너를 사용합니다.
예를 들어 Dockerfile은 다음과 같이 보일 수 있습니다.
FROM python:3.10
# v5e, v6e specific requirement - enable PJRT runtime
ENV PJRT_DEVICE=TPU
# install pytorch and torch_xla
RUN pip3 install torch~=2.1.0 torchvision torch_xla[tpu]~=2.1.0
 -f https://storage.googleapis.com/libtpu-releases/index.html
# Add your artifacts here
COPY trainer.py .
# Run the trainer code
CMD ["python3", "trainer.py"]
TPU Pod
학습은 TPU Pod의 모든 호스트에서 실행됩니다(TPU Pod 슬라이스에서 PyTorch 코드 실행 참조).
Vertex AI는 모든 호스트의 응답을 기다려 작업 완료를 결정합니다.
JAX 학습
사전 빌드된 컨테이너
JAX에는 사전 빌드된 컨테이너가 없습니다.
커스텀 컨테이너
JAX 라이브러리가 설치된 커스텀 컨테이너를 사용합니다.
예를 들어 Dockerfile은 다음과 같이 보일 수 있습니다.
# Install JAX.
RUN pip install 'jax[tpu]>=0.4.6' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
# Add your artifacts here
COPY trainer.py trainer.py
# Set an entrypoint.
ENTRYPOINT ["python3", "trainer.py"]
TPU Pod
학습은 TPU Pod의 모든 호스트에서 실행됩니다(TPU Pod 슬라이스에서 JAX 코드 실행 참조).
Vertex AI는 TPU Pod의 첫 번째 호스트를 감시하여 작업 완료를 결정합니다. 다음 코드 스니펫을 사용하여 모든 호스트가 동시에 종료되는지 확인할 수 있습니다.
# Your training logic
...
if jax.process_count() > 1:
  # Make sure all hosts stay up until the end of main.
  x = jnp.ones([jax.local_device_count()])
  x = jax.device_get(jax.pmap(lambda x: jax.lax.psum(x, 'i'), 'i')(x))
  assert x[0] == jax.device_count()
환경 변수
다음 표에서는 컨테이너 내에서 사용할 수 있는 환경 변수에 대해 자세히 설명합니다.
| 이름 | 값 | 
|---|---|
| TPU_NODE_NAME | my-first-tpu-node | 
| TPU_CONFIG | {"project": "tenant-project-xyz", "zone": "us-central1-b", "tpu_node_name": "my-first-tpu-node"} | 
커스텀 서비스 계정
커스텀 서비스 계정을 TPU 학습에 사용할 수 있습니다. 커스텀 서비스 계정을 사용하는 방법은 커스텀 서비스 계정 사용 방법 페이지를 참고하세요.
학습용 비공개 IP(VPC 네트워크 피어링)
비공개 IP를 TPU 학습에 사용할 수 있습니다. 커스텀 학습에 비공개 IP를 사용하는 방법 페이지를 참고하세요.
VPC 서비스 제어
VPC 서비스 제어가 사용 설정된 프로젝트는 TPU 학습 작업을 제출할 수 있습니다.
제한사항
TPU VM을 사용하여 학습할 때 다음 제한 사항이 적용됩니다.
TPU 유형
메모리 한도와 같이 TPU 가속기에 대한 자세한 내용은 TPU 유형을 참조하세요.