Training mit TPU-Beschleunigern

Vertex AI unterstützt das Training mit verschiedenen Frameworks und Bibliotheken mit einer TPU-VM. Beim Konfigurieren von Rechenressourcen können Sie TPU v2-, TPU v3- oder TPU v5e-VMs angeben. TPU v5e unterstützt JAX 0.4.6 und höher, TensorFlow 2.15 und höher sowie PyTorch 2.1 und höher. TPU v6e unterstützt Python 3.10 und höher, JAX 0.4.37 und höher sowie PyTorch 2.1 und höher. Dabei wird PJRT als Standardlaufzeit verwendet.

Weitere Informationen zum Konfigurieren von TPU-VMs für das serverlose Training in Vertex AI finden Sie unter Compute-Ressourcen für serverloses Training konfigurieren.

TensorFlow-Training

Vordefinierter Container

Verwenden Sie einen vordefinierten Trainingscontainer, der TPUs unterstützt, und erstellen Sie eine Python-Trainingsanwendung.

Benutzerdefinierter Container

Verwenden Sie einen benutzerdefinierten Container, in dem Sie Versionen von tensorflow und libtpu installiert haben, die speziell für TPU-VMs erstellt wurden. Diese Bibliotheken werden vom Cloud TPU-Dienst verwaltet und in der Dokumentation Unterstützte TPU-Konfigurationen aufgeführt.

Wählen Sie die gewünschte tensorflow-Version und die entsprechende libtpu-Bibliothek aus. Installieren Sie diese dann beim Erstellen des Containers im Docker-Container-Image.

Wenn Sie beispielsweise TensorFlow 2.15 verwenden möchten, fügen Sie Ihrem Dockerfile die folgende Anleitung hinzu:

  # Download and install `tensorflow`.
  RUN pip install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/tensorflow/tf-2.15.0/tensorflow-2.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

  # Download and install `libtpu`.
  # You must save `libtpu.so` in the '/lib' directory of the container image.
  RUN curl -L https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/1.9.0/libtpu.so -o /lib/libtpu.so

  # TensorFlow training on TPU v5e requires the PJRT runtime. To enable the PJRT
  # runtime, configure the following environment variables in your Dockerfile.
  # For details, see https://cloud.google.com/tpu/docs/runtimes#tf-pjrt-support.
  # ENV NEXT_PLUGGABLE_DEVICE_USE_C_API=true
  # ENV TF_PLUGGABLE_DEVICE_LIBRARY_PATH=/lib/libtpu.so

TPU-Pod

Das Training von tensorflow auf einem TPU Pod erfordert eine zusätzliche Einrichtung im Trainingscontainer. Vertex AI verwaltet ein Basis-Docker-Image, das die Ersteinrichtung übernimmt.

Image-URIs Python-Version
  • us-docker.pkg.dev/vertex-ai/training/tf-tpu-pod-base-cp38:latest
  • europe-docker.pkg.dev/vertex-ai/training/tf-tpu-pod-base-cp38:latest
  • asia-docker.pkg.dev/vertex-ai/training/tf-tpu-pod-base-cp38:latest
Python 3.8
  • us-docker.pkg.dev/vertex-ai/training/tf-tpu-pod-base-cp310:latest
  • europe-docker.pkg.dev/vertex-ai/training/tf-tpu-pod-base-cp310:latest
  • asia-docker.pkg.dev/vertex-ai/training/tf-tpu-pod-base-cp310:latest
Python 3.10

So erstellen Sie Ihren benutzerdefinierten Container:

  1. Wählen Sie das Basis-Image für die Python-Version Ihrer Wahl aus. TPU TensorFlow Wheels für TensorFlow 2.12 und niedriger unterstützen Python 3.8. TensorFlow 2.13 und höher unterstützen Python 3.10 oder höher. Informationen zu den jeweiligen TensorFlow-Räumen finden Sie unter Cloud TPU-Konfigurationen.
  2. Erweitern Sie das Image mit Ihrem Trainercode und dem Startbefehl.
# Specifies base image and tag
FROM us-docker.pkg.dev/vertex-ai/training/tf-tpu-pod-base-cp310:latest
WORKDIR /root

# Download and install `tensorflow`.
RUN pip install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/tensorflow/tf-2.15.0/tensorflow-2.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

# Download and install `libtpu`.
# You must save `libtpu.so` in the '/lib' directory of the container image.
RUN curl -L https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/1.9.0/libtpu.so -o /lib/libtpu.so

# Copies the trainer code to the docker image.
COPY your-path-to/model.py /root/model.py
COPY your-path-to/trainer.py /root/trainer.py

# The base image is setup so that it runs the CMD that you provide.
# You can provide CMD inside the Dockerfile like as follows.
# Use CMD, not ENTRYPOINT, to avoid accidentally overriding the pod base image's ENTRYPOINT.
# Alternatively, you can pass it as an `args` value in ContainerSpec:
# (https://cloud.google.com/vertex-ai/docs/reference/rest/v1/CustomJobSpec#containerspec)
CMD ["python3", "trainer.py"]

PyTorch-Training

Sie können für das Training mit TPUs vordefinierte oder benutzerdefinierte Container für PyTorch verwenden.

Vordefinierter Container

Verwenden Sie einen vordefinierten Trainingscontainer, der TPUs unterstützt, und erstellen Sie eine Python-Trainingsanwendung.

Benutzerdefinierter Container

Verwenden Sie einen benutzerdefinierten Container, in dem Sie die PyTorch-Bibliothek installiert haben.

Ihr Dockerfile könnte beispielsweise so aussehen:

FROM python:3.10

# v5e, v6e specific requirement - enable PJRT runtime
ENV PJRT_DEVICE=TPU

# install pytorch and torch_xla
RUN pip3 install torch~=2.1.0 torchvision torch_xla[tpu]~=2.1.0
 -f https://storage.googleapis.com/libtpu-releases/index.html

# Add your artifacts here
COPY trainer.py .

# Run the trainer code
CMD ["python3", "trainer.py"]

TPU-Pod

Das Training wird auf allen Hosts des TPU-Pods ausgeführt (siehe PyTorch-Code auf TPU-Pod-Slices ausführen).

Vertex AI wartet auf eine Antwort von allen Hosts, um den Abschluss des Jobs zu entscheiden.

JAX-Training

Wenn Sie JAX-Training in Vertex AI ausführen möchten, müssen Sie ein Container-Image mit Ihrer JAX-Umgebung bereitstellen.

Vordefinierter Container

Es gibt keine vordefinierten Container für JAX.

Benutzerdefinierter Container

Verwenden Sie einen benutzerdefinierten Container, in dem Sie die JAX-Bibliothek installiert haben.

Ihr Dockerfile könnte beispielsweise so aussehen:

# Install JAX.
RUN pip install 'jax[tpu]>=0.4.6' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

# Add your artifacts here
COPY trainer.py trainer.py

# Set an entrypoint.
ENTRYPOINT ["python3", "trainer.py"]

TPU-Pod

Das Training wird auf allen Hosts des TPU-Pods ausgeführt (siehe JAX-Code auf TPU-Pod-Slices ausführen).

Vertex AI überwacht den ersten Host des TPU-Pods, um über den Abschluss des Jobs zu entscheiden. Mit dem folgenden Code-Snippet können Sie dafür sorgen, dass alle Hosts gleichzeitig beendet werden:

# Your training logic
...

if jax.process_count() > 1:
  # Make sure all hosts stay up until the end of main.
  x = jnp.ones([jax.local_device_count()])
  x = jax.device_get(jax.pmap(lambda x: jax.lax.psum(x, 'i'), 'i')(x))
  assert x[0] == jax.device_count()

Umgebungsvariablen

In der folgenden Tabelle werden die Umgebungsvariablen beschrieben, die Sie im Container verwenden können:

Name Wert
TPU_NODE_NAME my-first-tpu-node
TPU_CONFIG {"project": "tenant-project-xyz", "zone": "us-central1-b", "tpu_node_name": "my-first-tpu-node"}

Benutzerdefiniertes Dienstkonto

Ein benutzerdefiniertes Dienstkonto kann für das TPU-Training verwendet werden. Informationen zur Verwendung eines benutzerdefinierten Dienstkontos finden Sie auf der Seite zur Verwendung eines benutzerdefinierten Dienstkontos.

Private IP-Adresse (VPC-Netzwerk-Peering) für das Training

Eine private IP-Adresse kann für das TPU-Training verwendet werden. Informationen dazu finden Sie auf der Seite Private IP-Adresse für serverloses Training verwenden.

VPC Service Controls

VPC Service Controls-fähige Projekte können TPU-Trainingsjobs senden.

Beschränkungen

Beim Trainieren mit einer TPU-VM gelten die folgenden Einschränkungen:

TPU-Typen

Weitere Informationen zu TPU-Beschleunigern wie das Arbeitsspeicherlimit finden Sie unter TPU-Typen.