使用 TPU7x (Ironwood) 训练模型

本文档介绍了如何预配 TPU7x 资源,并提供了一个使用 MaxText 和 XPK 部署训练工作负载的示例。

TPU7x 是 Ironwood 系列中的首个版本,也是 Google Cloud的第七代 TPU。Ironwood 世代专为大规模 AI 训练和推理而设计。如需了解详情,请参阅 TPU7x

如需查看更多针对 TPU7x 进行优化的示例,请参阅 GitHub 上的 Ironwood TPU 的训练 recipe

预配 TPU

您可以使用以下方法预配和管理 TPU7x:

  • GKE:您可以使用 GKE 将 TPU 作为加速器池进行预配和管理,以用于容器化机器学习工作负载。使用 Google Cloud CLI 手动创建 GKE 集群实例,以便精确自定义或扩展现有生产 GKE 环境。如需了解详情,请参阅 GKE 中的 TPU 简介
  • GKE 和 XPK:XPK 是一种命令行工具,可简化 GKE 上的集群创建和工作负载执行。它专为机器学习从业者而设计,可用于预配 TPU 和运行训练作业,而无需具备深厚的 Kubernetes 专业知识。使用 XPK 快速创建 GKE 集群并运行工作负载,以进行概念验证和测试。如需了解详情,请参阅 XPK GitHub 仓库
  • GKE 和 TPU Cluster Director:TPU Cluster Director 通过全容量模式预留提供,让您可以完全访问所有预留的容量(无保留),并全面了解 TPU 硬件拓扑、利用率状态和健康状态。如需了解详情,请参阅全容量模式概览

使用 MaxText 和 XPK 部署训练工作负载

使用加速处理套件 (XPK) 创建 GKE 集群以进行概念验证和测试。XPK 是一种命令行工具,旨在简化机器学习工作负载的预配、管理和运行。

以下部分介绍了如何使用 MaxTextXPK 部署训练工作负载。

准备工作

在开始之前,请完成以下步骤:

  • 确保您拥有启用了结算功能的 Google Cloud 项目。
  • 获取 TPU7x 访问权限。如需了解详情,请与您的客户支持团队联系。
  • 确保您用于 XPK 的账号具有 XPK GitHub 仓库中列出的角色。

安装 XPK 和依赖项

  1. 安装 XPK。按照 XPK GitHub 仓库中的说明进行操作。

  2. 按照管理员提供的说明安装 Docker,或按照官方安装说明操作。安装完成后,运行以下命令来配置 Docker 并测试安装:

    gcloud auth configure-docker
    sudo usermod -aG docker $USER # relaunch the terminal and activate venv after running this command
    docker run hello-world # Test Docker
    
  3. 设置以下环境变量:

    export PROJECT_ID=YOUR_PROJECT_ID
    export ZONE=YOUR_ZONE
    export CLUSTER_NAME=YOUR_CLUSTER_NAME
    export ACCELERATOR_TYPE=YOUR_ACCELERATOR_TYPE
    export RESERVATION_NAME=YOUR_RESERVATION_NAME
    export BASE_OUTPUT_DIR="gs://YOUR_BUCKET_NAME"

    替换以下内容:

    • YOUR_PROJECT_ID:您的 Google Cloud 项目 ID。
    • YOUR_ZONE:要在其中创建集群的可用区。对于预览版,仅支持 us-central1-c
    • YOUR_CLUSTER_NAME:新集群的名称。
    • YOUR_ACCELERATOR_TYPE:TPU 版本和拓扑。例如 tpu7x-4x4x8。如需查看受支持的拓扑列表,请参阅受支持的配置
    • YOUR_RESERVATION_NAME:预留的名称。 对于共享预留,请使用 projects/YOUR_PROJECT_NUMBER/reservations/YOUR_RESERVATION_NAME
    • YOUR_BUCKET_NAME:Cloud Storage 存储桶的名称,该存储桶将作为模型训练的输出目录。
  4. 如果您还没有 Cloud Storage 存储桶,请使用以下命令创建一个:

    gcloud storage buckets create ${BASE_OUTPUT_DIR} \
        --project=${PROJECT_ID} \
        --location=US \
        --default-storage-class=STANDARD \
        --uniform-bucket-level-access
    

创建单 NIC、单切片集群

  1. 按照配置 MTU 部分中的说明优化网络配置。

  2. 填充 ${CLUSTER_ARGUMENTS} 变量,您将在 xpk cluster create 命令中使用该变量:

    export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${SUBNET_NAME}"
    
  3. 使用 xpk cluster create 命令创建具有 TPU7x 节点池的 GKE 集群:

    xpk cluster create \
        --project=${PROJECT_ID} \
        --zone=${ZONE} \
        --cluster ${CLUSTER_NAME} \
        --cluster-cpu-machine-type=n1-standard-8 \
        --tpu-type=${ACCELERATOR_TYPE} \
        --reservation=${RESERVATION_NAME} \
        --custom-cluster-arguments="${CLUSTER_ARGUMENTS}"
    

    --cluster-cpu-machine-type 标志设置为 n1-standard-8(或更大值)可确保默认节点池具有足够的 CPU 来运行系统 pod(例如 JobSet webhook),从而防止出现错误。 默认情况下,XPK 使用 e2-standard-16。某些可用区仅支持特定 CPU 类型,因此您可能需要在 n1n2e2 类型之间进行切换。否则,您可能会遇到配额错误

  4. 添加维护排除期以防止集群升级:

    gcloud container clusters update ${CLUSTER_NAME} \
        --zone=${ZONE} \
        --add-maintenance-exclusion-name="no-upgrade-next-month" \
        --add-maintenance-exclusion-start="EXCLUSION_START_TIME" \
        --add-maintenance-exclusion-end="EXCLUSION_END_TIME" \
        --add-maintenance-exclusion-scope="no_upgrades"

    替换以下内容:

    • EXCLUSION_START_TIME:您选择的维护排除期开始时间,采用 YYYY-MM-DDTHH:MM:SSZ 格式。
    • EXCLUSION_END_TIME:您选择的维护排除期结束时间,采用 YYYY-MM-DDTHH:MM:SSZ 格式。

构建或上传 MaxText Docker 映像

您可以使用 MaxText 提供的脚本在本地构建 Docker 映像,也可以使用预构建的映像。

在本地构建

以下命令会将本地目录复制到容器中:

# Make sure you're running on a virtual environment with python3.12. If nothing is printed, you have the correct version.
[[ "$(python3 -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")' 2>/dev/null)" == "3.12" ]] || { >&2 echo "Error: Python version must be 3.12."; false; }

# Clone MaxText
git clone https://github.com/AI-Hypercomputer/maxtext.git
cd maxtext
git checkout maxtext-tutorial-v1.0.0

# Custom Jax and LibTPU wheels
pip download libtpu==0.0.28.dev20251104+nightly -f "https://storage.googleapis.com/jax-releases/libtpu_releases.html"
pip download --pre jax==0.8.1.dev20251104 jaxlib==0.8.1.dev20251104 --index https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/

# Build the Docker image
bash docker_build_dependency_image.sh MODE=custom_wheels

成功执行命令后,您应该会看到在本地创建的名为 maxtext_base_image 的映像。您可以直接在 xpk 工作负载命令中使用本地映像。

上传映像(可选)

按照上一部分中的说明在本地构建 Docker 映像后,您可以使用以下命令将 MaxText Docker 映像上传到注册表:

export CLOUD_IMAGE_NAME="${USER}-maxtext-runner"
bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${CLOUD_IMAGE_NAME}

成功执行此命令后,您应该会在 gcr.io 中看到名称为 gcr.io/PROJECT_ID/CLOUD_IMAGE_NAME 的 MaxText 映像。

定义 MaxText 训练命令

准备好在 Docker 容器中运行训练脚本的命令。

MaxText 1B 模型是 MaxText 框架内的一种配置,旨在训练具有约 10 亿参数的语言模型。使用此模型可试验小芯片规模。性能未优化。

export MAXTEXT_COMMAND="JAX_PLATFORMS=tpu,cpu \
    ENABLE_PJRT_COMPATIBILITY=true \
    python3 src/MaxText/train.py src/MaxText/configs/base.yml \
        base_output_directory=${BASE_OUTPUT_DIR} \
        dataset_type=synthetic \
        per_device_batch_size=2 \
        enable_checkpointing=false \
        gcs_metrics=true \
        run_name=maxtext_xpk \
        steps=30"

部署训练工作负载

运行 xpk workload create 命令以部署训练作业。您必须指定 --base-docker-image 标志以使用 MaxText 基础映像,或者指定 --docker-image 标志和要使用的映像。您可以选择添加 --enable-debug-logs 标志来启用调试日志记录。

xpk workload create \
    --cluster ${CLUSTER_NAME} \
    --base-docker-image maxtext_base_image \
    --workload maxtext-1b-$(date +%H%M) \
    --tpu-type=${ACCELERATOR_TYPE} \
    --zone ${ZONE} \
    --project ${PROJECT_ID} \
    --command "${MAXTEXT_COMMAND}"
    # [--enable-debug-logs]

工作负载名称在集群中必须是唯一的。在此示例中,$(date +%H%M) 附加到工作负载名称中,以确保唯一性。

后续步骤