Como treinar com aceleradores de TPU

A Vertex AI é compatível com o treinamento com vários frameworks e bibliotecas usando uma VM de TPU. Ao configurar recursos de computação, é possível especificar VMs de TPU v2, TPU v3 ou TPU v5e. A TPU v5e é compatível com JAX 0.4.6+, TensorFlow 2.15+ e PyTorch 2.1+. A TPU v6e é compatível com Python 3.10+, JAX 0.4.37+ e PyTorch 2.1+ usando PJRT como ambiente de execução padrão.

Para detalhes sobre como configurar VMs de TPU para treinamento sem servidor da Vertex AI, consulte Configurar recursos de computação para treinamento sem servidor.

Treinamento do TensorFlow

Contêiner pré-criado

Use um contêiner de treinamento pré-criado compatível com TPUs e crie um aplicativo de treinamento em Python.

Contêiner personalizado

Use um contêiner personalizado em que você instalou versões de tensorflow e libtpu especialmente criadas para VMs de TPU. Essas bibliotecas são mantidas pelo serviço do Cloud TPU e estão listadas na documentação sobre Configurações de TPU compatíveis.

Selecione a versão tensorflow e a biblioteca libtpu correspondente de sua escolha. Em seguida, instale-as na imagem do contêiner do Docker ao criar o contêiner.

Por exemplo, se você quiser usar o TensorFlow 2.15, inclua as seguintes instruções no Dockerfile:

  # Download and install `tensorflow`.
  RUN pip install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/tensorflow/tf-2.15.0/tensorflow-2.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

  # Download and install `libtpu`.
  # You must save `libtpu.so` in the '/lib' directory of the container image.
  RUN curl -L https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/1.9.0/libtpu.so -o /lib/libtpu.so

  # TensorFlow training on TPU v5e requires the PJRT runtime. To enable the PJRT
  # runtime, configure the following environment variables in your Dockerfile.
  # For details, see https://cloud.google.com/tpu/docs/runtimes#tf-pjrt-support.
  # ENV NEXT_PLUGGABLE_DEVICE_USE_C_API=true
  # ENV TF_PLUGGABLE_DEVICE_LIBRARY_PATH=/lib/libtpu.so

Pod de TPU

O treinamento tensorflow em um TPU Pod requer configuração adicional no contêiner de treinamento. A Vertex AI mantém uma imagem do Docker básica que processa a configuração inicial.

URIs de imagem Versão do Python
  • us-docker.pkg.dev/vertex-ai/training/tf-tpu-pod-base-cp38:latest
  • europe-docker.pkg.dev/vertex-ai/training/tf-tpu-pod-base-cp38:latest
  • asia-docker.pkg.dev/vertex-ai/training/tf-tpu-pod-base-cp38:latest
Python 3.8
  • us-docker.pkg.dev/vertex-ai/training/tf-tpu-pod-base-cp310:latest
  • europe-docker.pkg.dev/vertex-ai/training/tf-tpu-pod-base-cp310:latest
  • asia-docker.pkg.dev/vertex-ai/training/tf-tpu-pod-base-cp310:latest
Python 3.10

Siga estas etapas para criar seu contêiner personalizado:

  1. Escolha a imagem de base para a versão do Python de sua escolha. As rodas do TensorFlow de TPU para TensorFlow 2.12 e versões anteriores são compatíveis com Python 3.8. O TensorFlow 2.13 e versões mais recentes oferecem suporte ao Python 3.10 ou versões mais recentes. Para conhecer as rodas específicas do TensorFlow, consulte Configurações do Cloud TPU.
  2. Amplie a imagem com seu código treinador e o comando de inicialização.
# Specifies base image and tag
FROM us-docker.pkg.dev/vertex-ai/training/tf-tpu-pod-base-cp310:latest
WORKDIR /root

# Download and install `tensorflow`.
RUN pip install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/tensorflow/tf-2.15.0/tensorflow-2.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

# Download and install `libtpu`.
# You must save `libtpu.so` in the '/lib' directory of the container image.
RUN curl -L https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/1.9.0/libtpu.so -o /lib/libtpu.so

# Copies the trainer code to the docker image.
COPY your-path-to/model.py /root/model.py
COPY your-path-to/trainer.py /root/trainer.py

# The base image is setup so that it runs the CMD that you provide.
# You can provide CMD inside the Dockerfile like as follows.
# Use CMD, not ENTRYPOINT, to avoid accidentally overriding the pod base image's ENTRYPOINT.
# Alternatively, you can pass it as an `args` value in ContainerSpec:
# (https://cloud.google.com/vertex-ai/docs/reference/rest/v1/CustomJobSpec#containerspec)
CMD ["python3", "trainer.py"]

Treinamento do PyTorch

É possível usar contêineres pré-criados ou personalizados para o PyTorch ao treinar com TPUs.

Contêiner pré-criado

Use um contêiner de treinamento pré-criado compatível com TPUs e crie um aplicativo de treinamento em Python.

Contêiner personalizado

Use um contêiner personalizado em que você instalou a biblioteca PyTorch.

Por exemplo, seu Dockerfile pode ter esta aparência:

FROM python:3.10

# v5e, v6e specific requirement - enable PJRT runtime
ENV PJRT_DEVICE=TPU

# install pytorch and torch_xla
RUN pip3 install torch~=2.1.0 torchvision torch_xla[tpu]~=2.1.0
 -f https://storage.googleapis.com/libtpu-releases/index.html

# Add your artifacts here
COPY trainer.py .

# Run the trainer code
CMD ["python3", "trainer.py"]

Pod de TPU

O treinamento é executado em todos os hosts do Pod de TPU. Consulte Executar o código PyTorch em frações do Pod de TPU.

A Vertex AI aguarda uma resposta de todos os hosts para decidir a conclusão do job.

Treinamento do JAX

Para realizar o treinamento do JAX na Vertex AI, é necessário fornecer uma imagem de contêiner com seu ambiente do JAX.

Contêiner pré-criado

Não há contêineres pré-criados para JAX.

Contêiner personalizado

Use um contêiner personalizado em que você instalou a biblioteca JAX.

Por exemplo, seu Dockerfile pode ter esta aparência:

# Install JAX.
RUN pip install 'jax[tpu]>=0.4.6' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

# Add your artifacts here
COPY trainer.py trainer.py

# Set an entrypoint.
ENTRYPOINT ["python3", "trainer.py"]

Pod de TPU

O treinamento é executado em todos os hosts do Pod de TPU. Consulte Executar o código JAX em frações do Pod de TPU.

A Vertex AI monitora o primeiro host do Pod de TPU para decidir a conclusão do job. Use o seguinte snippet de código para garantir que todos os hosts saiam ao mesmo tempo:

# Your training logic
...

if jax.process_count() > 1:
  # Make sure all hosts stay up until the end of main.
  x = jnp.ones([jax.local_device_count()])
  x = jax.device_get(jax.pmap(lambda x: jax.lax.psum(x, 'i'), 'i')(x))
  assert x[0] == jax.device_count()

Variáveis de ambiente

A tabela a seguir detalha as variáveis de ambiente que podem ser usadas no contêiner:

Nome Valor
TPU_NODE_NAME my-first-tpu-node
TPU_CONFIG {"project": "tenant-project-xyz", "zone": "us-central1-b", "tpu_node_name": "my-first-tpu-node"}

Conta de serviço personalizada

Uma conta de serviço personalizada pode ser usada para o treinamento da TPU. Para saber como usar uma conta de serviço personalizada, consulte a página sobre como usar uma conta de serviço personalizada.

IP particular (peering de rede VPC) para treinamento

Um IP privado pode ser usado no treinamento da TPU. Consulte a página sobre como usar um IP privado para treinamento sem servidor.

VPC Service Controls

Os projetos ativados pelo VPC Service Controls podem enviar jobs de treinamento de TPU.

Limitações

As limitações a seguir se aplicam quando você treina usando uma VM de TPU:

Tipos de TPU

Consulte Tipos de TPU para mais informações sobre aceleradores de TPU, como limite de memória.