使用 Pathways 运行批量工作负载

就本文档而言,批处理工作负载是指执行到完成并部署在与 Pathways 集群相同的 GKE 集群内(具体来说,是与 Pathways 控制器组件 [IFRT 代理服务器和 Pathways 资源管理器] 并行)的 JAX 工作负载。JAX 工作负载完成后,Pathways 集群组件会终止。 本指南使用 JAX 训练工作负载来演示这一点。

准备工作

请确保您已备妥:

使用 Maxtext 构建训练映像

MaxText 是 Google 开发的开源大语言模型 (LLM) 项目。它采用 JAX 编写,旨在实现高性能和可伸缩性,可在 Google Cloud TPU 和 GPU 上高效运行。

如需使用 OSS GitHub 代码库中的最新稳定版 JAX 构建 MaxText Docker 映像,请运行以下命令:

git clone https://github.com/AI-Hypercomputer/maxtext
cd maxtext/dependencies/scripts
gcloud config set project PROJECT
bash ./docker_build_dependency_image.sh MODE=stable
gcloud auth configure-docker
bash ./docker_upload_runner.sh CLOUD_IMAGE_NAME=USER_runner # This script needs bash version >= 4.2 to execute.

此命令会将 MaxText Kubernetes 映像推送到 gcr.io/$PROJECT/${USER}_runner。您可以使用此 Docker 映像通过 Pathways 后端在 TPU 上运行训练。

使用 PathwaysJob API 运行批量工作负载

以下清单使用 PathwaysJob API 部署 Pathways 组件并运行 MaxText 工作负载。工作负载封装在 main 容器中,并执行 train.py

将以下 YAML 复制到名为 pathways-job-batch-training.yaml 的文件中,并更新可编辑的值。

apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USER
spec:
  maxRestarts: MAX_RESTARTS
  workers:
  - type: TPU_MACHINE_TYPE
    topology: TOPOLOGY
    numSlices: WORKLOAD_NODEPOOL_COUNT
  pathwaysDir: gs://BUCKET_NAME
  controller:
    deploymentMode: default
    template:
      spec:
        containers:
        - name: main
          image: gcr.io/PROJECT/USER_runner
          command:
          - bash
          - -c
          - |
            python3 -m MaxText.train /deps/src/MaxText/configs/base.yml \
            base_output_directory=gs://BUCKET_NAME \
            run_name=RUN_NAME \
            per_device_batch_size=1 \
            enable_checkpointing=false \
            remat_policy=full \
            global_parameter_scale=1 \
            steps=20 \
            max_target_length=2048 \
            use_iota_embed=true \
            reuse_example_batch=1 \
            dataset_type=synthetic \
            attention=flash \
            gcs_metrics=True \
            enable_single_controller=True

替换以下内容:

  • USER:您的 Google Cloud 用户 ID
  • MAX_RESTARTS:作业可重启的最大次数
  • TPU_MACHINE_TYPETPU 机器类型
  • TOPOLOGY:TPU v4 或更高版本的拓扑。如需详细了解 TPU 版本和支持的拓扑,请参阅 TPU 版本
  • WORKLOAD_NODEPOOL_COUNT:Pathways 工作负载使用的节点池数量
  • BUCKET_NAME:用于存储临时文件的 Cloud Storage 存储桶
  • PROJECT:您的 Google Cloud 项目 ID
  • RUN_NAME:用户分配的用于标识工作流运行的名称

您可以按如下方式部署 PathwaysJob YAML:

kubectl apply -f pathways-job-batch-training.yaml

如需查看上一个命令创建的 PathwaysJob 实例,请使用以下命令:

kubectl get pathwaysjob

输出应如下所示:

NAME             AGE
pathways-trial   9s

如需修改 PathwaysJob 实例的属性,请删除 PathwaysJob 实例,修改 YAML 并应用它来创建新的 PathwaysJob 实例。

您可以前往 JAX 容器的日志浏览器,然后在“容器名称”过滤条件下选择 main,以跟踪工作负载的进度。

您应该会看到如下所示的日志,表明训练正在进行中。 工作负载将在 30 步后完成。

completed step: 1, seconds: 0.484, TFLOP/s/device: 87.349, Tokens/s/device: 2117.382, total_weights: 2945, loss: 10.888
completed step: 2, seconds: 0.407, TFLOP/s/device: 103.699, Tokens/s/device: 2513.735, total_weights: 3253, loss: 9.697
completed step: 3, seconds: 0.248, TFLOP/s/device: 170.300, Tokens/s/device: 4128.167, total_weights: 3154, loss: 9.641
completed step: 4, seconds: 0.216, TFLOP/s/device: 195.122, Tokens/s/device: 4729.880, total_weights: 3119, loss: 9.547
completed step: 5, seconds: 0.272, TFLOP/s/device: 155.298, Tokens/s/device: 3764.512, total_weights: 2837, loss: 10.179
completed step: 6, seconds: 0.472, TFLOP/s/device: 89.489, Tokens/s/device: 2169.266, total_weights: 3069, loss: 9.776

如需删除 PathwaysJob 实例,您可以使用以下命令:

kubectl delete -f pathways-job-batch-training.yaml

使用 XPK 运行批量工作负载

现在,您可以使用 XPK 提交预构建的 Maxtext Docker 映像,所用命令与之前相同。

xpk workload create-pathways \
--workload=WORKLOAD \
--cluster=CLUSTER \
--num-slices=WORKLOAD_NODEPOOL_COUNT \
--tpu-type=TPU_TYPE \
--project=PROJECT \
--zone=ZONE \
--docker-image='gcr.io/PROJECT/USER_runner' \
--command="python3 -m MaxText.train /deps/src/MaxText/configs/base.yml base_output_directory=gs://BUCKET_NAME per_device_batch_size=1 enable_checkpointing=false remat_policy=full global_parameter_scale=1 steps=20 max_target_length=2048 use_iota_embed=true reuse_example_batch=1 dataset_type=synthetic attention=flash gcs_metrics=True enable_single_controller=True run_name=RUN_NAME-pathways-job"

替换以下内容:

  • WORKLOAD:用于标识工作负载的唯一名称
  • CLUSTER:GKE 集群的名称
  • WORKLOAD_NODEPOOL_COUNT:作业可重启的最大次数
  • TPU_TYPE:TPU 类型用于指定您要创建的 Cloud TPU 的版本和大小。如需详细了解每个 TPU 版本支持的 TPU 类型,请参阅 TPU 版本
  • PROJECT:您的 Google Cloud 项目 ID
  • ZONE:您计划运行工作负载的可用区
  • USER:您的 Google Cloud 用户 ID
  • RUN_NAME:用户分配的用于标识工作流运行的名称

您应该会看到如下所示的输出:

[XPK] Follow your Pathways workload and other resources here : https://console.cloud.google.com/logs/query;query=resource.type%3D"k8s_container"%0Aresource.labels.project_id%3D"<project-name>"%0Aresource.labels.location%3D"<your-zone>"%0Aresource.labels.cluster_name%3D"<your-cluster-name>"%0Aresource.labels.pod_name:"<your-pod-name>"%0Aseverity>%3DDEFAULT

使用上一个 XPK 命令的输出中的链接来跟踪工作负载的进度。您可以在“容器名称”过滤条件下选择 jax-tpu,以过滤 JAX 容器的日志。

completed step: 1, seconds: 0.484, TFLOP/s/device: 87.349, Tokens/s/device: 2117.382, total_weights: 2945, loss: 10.888
completed step: 2, seconds: 0.407, TFLOP/s/device: 103.699, Tokens/s/device: 2513.735, total_weights: 3253, loss: 9.697
completed step: 3, seconds: 0.248, TFLOP/s/device: 170.300, Tokens/s/device: 4128.167, total_weights: 3154, loss: 9.641
completed step: 4, seconds: 0.216, TFLOP/s/device: 195.122, Tokens/s/device: 4729.880, total_weights: 3119, loss: 9.547
completed step: 5, seconds: 0.272, TFLOP/s/device: 155.298, Tokens/s/device: 3764.512, total_weights: 2837, loss: 10.179
completed step: 6, seconds: 0.472, TFLOP/s/device: 89.489, Tokens/s/device: 2169.266, total_weights: 3069, loss: 9.776

工作负载在指定步数后完成。 如果您想提前终止,请使用以下命令:

xpk workload delete --workload=WORKLOAD --cluster=CLUSTER --project=PROJECT --zone=ZONE

后续步骤