就本文档而言,批处理工作负载是指执行到完成并部署在与 Pathways 集群相同的 GKE 集群内(具体来说,是与 Pathways 控制器组件 [IFRT 代理服务器和 Pathways 资源管理器] 并行)的 JAX 工作负载。JAX 工作负载完成后,Pathways 集群组件会终止。 本指南使用 JAX 训练工作负载来演示这一点。
准备工作
请确保您已备妥:
使用 Maxtext 构建训练映像
MaxText 是 Google 开发的开源大语言模型 (LLM) 项目。它采用 JAX 编写,旨在实现高性能和可伸缩性,可在 Google Cloud TPU 和 GPU 上高效运行。
如需使用 OSS GitHub 代码库中的最新稳定版 JAX 构建 MaxText Docker 映像,请运行以下命令:
git clone https://github.com/AI-Hypercomputer/maxtext cd maxtext/dependencies/scripts gcloud config set project PROJECT bash ./docker_build_dependency_image.sh MODE=stable gcloud auth configure-docker bash ./docker_upload_runner.sh CLOUD_IMAGE_NAME=USER_runner # This script needs bash version >= 4.2 to execute.
此命令会将 MaxText Kubernetes 映像推送到 gcr.io/$PROJECT/${USER}_runner。您可以使用此 Docker 映像通过 Pathways 后端在 TPU 上运行训练。
使用 PathwaysJob API 运行批量工作负载
以下清单使用 PathwaysJob API 部署 Pathways 组件并运行 MaxText 工作负载。工作负载封装在 main 容器中,并执行 train.py。
将以下 YAML 复制到名为 pathways-job-batch-training.yaml 的文件中,并更新可编辑的值。
apiVersion: pathways-job.pathways.domain/v1 kind: PathwaysJob metadata: name: pathways-USER spec: maxRestarts: MAX_RESTARTS workers: - type: TPU_MACHINE_TYPE topology: TOPOLOGY numSlices: WORKLOAD_NODEPOOL_COUNT pathwaysDir: gs://BUCKET_NAME controller: deploymentMode: default template: spec: containers: - name: main image: gcr.io/PROJECT/USER_runner command: - bash - -c - | python3 -m MaxText.train /deps/src/MaxText/configs/base.yml \ base_output_directory=gs://BUCKET_NAME \ run_name=RUN_NAME \ per_device_batch_size=1 \ enable_checkpointing=false \ remat_policy=full \ global_parameter_scale=1 \ steps=20 \ max_target_length=2048 \ use_iota_embed=true \ reuse_example_batch=1 \ dataset_type=synthetic \ attention=flash \ gcs_metrics=True \ enable_single_controller=True
替换以下内容:
USER:您的 Google Cloud 用户 IDMAX_RESTARTS:作业可重启的最大次数TPU_MACHINE_TYPE:TPU 机器类型TOPOLOGY:TPU v4 或更高版本的拓扑。如需详细了解 TPU 版本和支持的拓扑,请参阅 TPU 版本WORKLOAD_NODEPOOL_COUNT:Pathways 工作负载使用的节点池数量BUCKET_NAME:用于存储临时文件的 Cloud Storage 存储桶PROJECT:您的 Google Cloud 项目 IDRUN_NAME:用户分配的用于标识工作流运行的名称
您可以按如下方式部署 PathwaysJob YAML:
kubectl apply -f pathways-job-batch-training.yaml
如需查看上一个命令创建的 PathwaysJob 实例,请使用以下命令:
kubectl get pathwaysjob
输出应如下所示:
NAME AGE pathways-trial 9s
如需修改 PathwaysJob 实例的属性,请删除 PathwaysJob 实例,修改 YAML 并应用它来创建新的 PathwaysJob 实例。
您可以前往 JAX 容器的日志浏览器,然后在“容器名称”过滤条件下选择 main,以跟踪工作负载的进度。
您应该会看到如下所示的日志,表明训练正在进行中。 工作负载将在 30 步后完成。
completed step: 1, seconds: 0.484, TFLOP/s/device: 87.349, Tokens/s/device: 2117.382, total_weights: 2945, loss: 10.888 completed step: 2, seconds: 0.407, TFLOP/s/device: 103.699, Tokens/s/device: 2513.735, total_weights: 3253, loss: 9.697 completed step: 3, seconds: 0.248, TFLOP/s/device: 170.300, Tokens/s/device: 4128.167, total_weights: 3154, loss: 9.641 completed step: 4, seconds: 0.216, TFLOP/s/device: 195.122, Tokens/s/device: 4729.880, total_weights: 3119, loss: 9.547 completed step: 5, seconds: 0.272, TFLOP/s/device: 155.298, Tokens/s/device: 3764.512, total_weights: 2837, loss: 10.179 completed step: 6, seconds: 0.472, TFLOP/s/device: 89.489, Tokens/s/device: 2169.266, total_weights: 3069, loss: 9.776
如需删除 PathwaysJob 实例,您可以使用以下命令:
kubectl delete -f pathways-job-batch-training.yaml
使用 XPK 运行批量工作负载
现在,您可以使用 XPK 提交预构建的 Maxtext Docker 映像,所用命令与之前相同。
xpk workload create-pathways \ --workload=WORKLOAD \ --cluster=CLUSTER \ --num-slices=WORKLOAD_NODEPOOL_COUNT \ --tpu-type=TPU_TYPE \ --project=PROJECT \ --zone=ZONE \ --docker-image='gcr.io/PROJECT/USER_runner' \ --command="python3 -m MaxText.train /deps/src/MaxText/configs/base.yml base_output_directory=gs://BUCKET_NAME per_device_batch_size=1 enable_checkpointing=false remat_policy=full global_parameter_scale=1 steps=20 max_target_length=2048 use_iota_embed=true reuse_example_batch=1 dataset_type=synthetic attention=flash gcs_metrics=True enable_single_controller=True run_name=RUN_NAME-pathways-job"
替换以下内容:
WORKLOAD:用于标识工作负载的唯一名称CLUSTER:GKE 集群的名称WORKLOAD_NODEPOOL_COUNT:作业可重启的最大次数TPU_TYPE:TPU 类型用于指定您要创建的 Cloud TPU 的版本和大小。如需详细了解每个 TPU 版本支持的 TPU 类型,请参阅 TPU 版本PROJECT:您的 Google Cloud 项目 IDZONE:您计划运行工作负载的可用区USER:您的 Google Cloud 用户 IDRUN_NAME:用户分配的用于标识工作流运行的名称
您应该会看到如下所示的输出:
[XPK] Follow your Pathways workload and other resources here : https://console.cloud.google.com/logs/query;query=resource.type%3D"k8s_container"%0Aresource.labels.project_id%3D"<project-name>"%0Aresource.labels.location%3D"<your-zone>"%0Aresource.labels.cluster_name%3D"<your-cluster-name>"%0Aresource.labels.pod_name:"<your-pod-name>"%0Aseverity>%3DDEFAULT
使用上一个 XPK 命令的输出中的链接来跟踪工作负载的进度。您可以在“容器名称”过滤条件下选择 jax-tpu,以过滤 JAX 容器的日志。
completed step: 1, seconds: 0.484, TFLOP/s/device: 87.349, Tokens/s/device: 2117.382, total_weights: 2945, loss: 10.888 completed step: 2, seconds: 0.407, TFLOP/s/device: 103.699, Tokens/s/device: 2513.735, total_weights: 3253, loss: 9.697 completed step: 3, seconds: 0.248, TFLOP/s/device: 170.300, Tokens/s/device: 4128.167, total_weights: 3154, loss: 9.641 completed step: 4, seconds: 0.216, TFLOP/s/device: 195.122, Tokens/s/device: 4729.880, total_weights: 3119, loss: 9.547 completed step: 5, seconds: 0.272, TFLOP/s/device: 155.298, Tokens/s/device: 3764.512, total_weights: 2837, loss: 10.179 completed step: 6, seconds: 0.472, TFLOP/s/device: 89.489, Tokens/s/device: 2169.266, total_weights: 3069, loss: 9.776
工作负载在指定步数后完成。 如果您想提前终止,请使用以下命令:
xpk workload delete --workload=WORKLOAD --cluster=CLUSTER --project=PROJECT --zone=ZONE
后续步骤
- 使用 Pathways 执行多主机推理
- 使用 Pathways 运行交互式工作负载
- 将 JAX 工作负载迁移到 Pathways
- 通过 Pathways 进行弹性训练
- 排查 Pathways on Cloud 问题