在 Docker 容器中运行 TPU 工作负载

Docker 容器可将您的代码和所有必需的依赖项合并到一个可分发的软件包中,从而简化应用配置。您可以在 TPU 虚拟机中运行 Docker 容器,以简化 Cloud TPU 应用的配置和共享。本文档介绍了如何为 Cloud TPU 支持的每个机器学习框架设置 Docker 容器。

在 Docker 容器中训练 PyTorch 模型

TPU 设备

  1. 创建 Cloud TPU 虚拟机

    gcloud compute tpus tpu-vm create your-tpu-name \
    --zone=europe-west4-a \
    --accelerator-type=v2-8 \
    --version=tpu-ubuntu2204-base
  2. 使用 SSH 连接到 TPU 虚拟机

    gcloud compute tpus tpu-vm ssh your-tpu-name \
    --zone=europe-west4-a
  3. 确保您的 Google Cloud 用户已被授予 Artifact Registry Reader 角色。如需了解详情,请参阅授予 Artifact Registry 角色

  4. 使用夜间 PyTorch/XLA 映像在 TPU 虚拟机中启动容器

    sudo docker run --net=host -ti --rm --name your-container-name --privileged us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm_cxx11 \
    bash
  5. 配置 TPU 运行时

    PyTorch/XLA 有两种运行时选项:PJRT 和 XRT。除非有使用 XRT 的理由,否则我们建议您使用 PJRT。如需详细了解不同的运行时配置,请参阅 PJRT 运行时文档

    PJRT

    export PJRT_DEVICE=TPU

    XRT

    export XRT_TPU_CONFIG="localservice;0;localhost:51011"
  6. 克隆 PyTorch XLA 仓库

    git clone --recursive https://github.com/pytorch/xla.git
  7. 训练 ResNet50

    python3 xla/test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1

训练脚本完成后,请务必清理资源。

  1. 输入 exit 以退出 Docker 容器
  2. 输入 exit 以退出 TPU 虚拟机
  3. 删除 TPU 虚拟机

    gcloud compute tpus tpu-vm delete your-tpu-name --zone=europe-west4-a

TPU 切片

在 TPU 切片上运行 PyTorch 代码时,您必须同时在所有 TPU 工作器上运行代码。完成此操作的一种方法是将 gcloud compute tpus tpu-vm ssh 命令与 --worker=all--command 标志结合使用。以下过程展示了如何创建 Docker 映像,以便更轻松地设置每个 TPU 工作器。

  1. 创建 TPU 虚拟机

    gcloud compute tpus tpu-vm create your-tpu-name \
    --zone=us-central2-b \
    --accelerator-type=v4-32 \
    --version=tpu-ubuntu2204-base
  2. 将当前用户添加到 Docker 组

    gcloud compute tpus tpu-vm ssh your-tpu-name \
    --zone=us-central2-b \
    --worker=all \
    --command='sudo usermod -a -G docker $USER'
  3. 克隆 PyTorch XLA 仓库

    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=us-central2-b \
    --command="git clone --recursive https://github.com/pytorch/xla.git"
  4. 在所有 TPU 工作器上的容器中运行训练脚本

    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=us-central2-b \
    --command="docker run --rm --privileged --net=host  -v ~/xla:/xla -e PJRT_DEVICE=TPU us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm_cxx11 python /xla/test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1"

    Docker 命令标志:

    • --rm 会在容器进程终止后移除容器。
    • --privileged 将 TPU 设备公开给容器。
    • --net=host 会将容器的所有端口绑定到 TPU 虚拟机,以允许 Pod 中的主机之间进行通信。
    • -e 会设置环境变量。

训练脚本完成后,请务必清理资源。

使用以下命令删除 TPU 虚拟机:

gcloud compute tpus tpu-vm delete your-tpu-name \
--zone=us-central2-b

在 Docker 容器中训练 JAX 模型

TPU 设备

  1. 创建 TPU 虚拟机

    gcloud compute tpus tpu-vm create your-tpu-name \
    --zone=europe-west4-a \
    --accelerator-type=v2-8 \
    --version=tpu-ubuntu2204-base
  2. 使用 SSH 连接到 TPU 虚拟机

    gcloud compute tpus tpu-vm ssh your-tpu-name  --zone=europe-west4-a
  3. 在 TPU 虚拟机中启动 Docker 守护进程

    sudo systemctl start docker
  4. 启动 Docker 容器

    sudo docker run --net=host -ti --rm --name your-container-name \
    --privileged us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm_cxx11 \
    bash
  5. 安装 JAX

    pip install jax[tpu]
  6. 安装 FLAX

    pip install --upgrade clu
    git clone https://github.com/google/flax.git
    pip install --user -e flax
  7. 安装 tensorflowtensorflow-dataset 软件包

    pip install tensorflow
    pip install tensorflow-datasets
  8. 运行 FLAX MNIST 训练脚本

    cd flax/examples/mnist
    python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=5

训练脚本完成后,请务必清理资源。

  1. 输入 exit 以退出 Docker 容器
  2. 输入 exit 以退出 TPU 虚拟机
  3. 删除 TPU 虚拟机

    gcloud compute tpus tpu-vm delete your-tpu-name --zone=europe-west4-a

TPU 切片

在 TPU 切片上运行 JAX 代码时,您必须同时在所有 TPU 工作器上运行 JAX 代码。完成此操作的一种方法是将 gcloud compute tpus tpu-vm ssh 命令与 --worker=all--command 标志结合使用。以下过程展示了如何创建 Docker 映像,以便更轻松地设置每个 TPU 工作器。

  1. 在当前目录中创建一个名为 Dockerfile 的文件,并将以下文本粘贴到其中

    FROM python:3.10
    RUN pip install jax[tpu]
    RUN pip install --upgrade clu
    RUN git clone https://github.com/google/flax.git
    RUN pip install --user -e flax
    RUN pip install tensorflow
    RUN pip install tensorflow-datasets
    WORKDIR ./flax/examples/mnist
  2. 准备 Artifact Registry

    gcloud artifacts repositories create your-repo \
    --repository-format=docker \
    --location=europe-west4 --description="Docker repository" \
    --project=your-project
    
    gcloud artifacts repositories list \
    --project=your-project
    
    gcloud auth configure-docker europe-west4-docker.pkg.dev
  3. 构建 Docker 映像

    docker build -t your-image-name .
  4. 在将 Docker 映像推送到 Artifact Registry 之前,为其添加标记。如需详细了解如何使用 Artifact Registry,请参阅使用容器映像

    docker tag your-image-name europe-west4-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag
  5. 将 Docker 映像推送到 Artifact Registry

    docker push europe-west4-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag
  6. 创建 TPU 虚拟机

    gcloud compute tpus tpu-vm create your-tpu-name \
    --zone=europe-west4-a \
    --accelerator-type=v2-8 \
    --version=tpu-ubuntu2204-base
  7. 在所有 TPU 工作器上从 Artifact Registry 拉取 Docker 映像

    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command='sudo usermod -a -G docker ${USER}'
    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command="gcloud auth configure-docker europe-west4-docker.pkg.dev --quiet"
    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command="docker pull europe-west4-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag"
  8. 在所有 TPU 工作器上运行容器

    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command="docker run -ti -d --privileged --net=host --name your-container-name europe-west4-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag bash"
  9. 在所有 TPU 工作器上运行训练脚本

    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command="docker exec --privileged your-container-name python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=5"

训练脚本完成后,请务必清理资源。

  1. 关停所有工作器上的容器

    gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all \
    --zone=europe-west4-a \
    --command="docker kill your-container-name"
  2. 删除 TPU 虚拟机

    gcloud compute tpus tpu-vm delete your-tpu-name \
    --zone=europe-west4-a

后续步骤