使用 Pathways 執行批次工作負載

就本文而言,批次工作負載是指執行完畢的 JAX 工作負載,且部署在與 Pathways 叢集相同的 GKE 叢集內,具體來說,是與 Pathways 控制器元件 (IFRT Proxy 伺服器和 Pathways 資源管理工具) 部署在同一位置。JAX 工作負載完成後,路徑叢集元件就會終止。本指南會使用 JAX 訓練工作負載來示範這項功能。

事前準備

請確認您已備妥以下項目:

使用 Maxtext 建構訓練映像檔

MaxText 是 Google 開發的開放原始碼大型語言模型 (LLM) 專案。以 JAX 編寫,專為高效能和可擴充性而設計,可在 Google Cloud TPU 和 GPU 上有效率地執行。

如要使用 OSS GitHub 存放區的最新穩定版 JAX 建構 MaxText Docker 映像檔,請執行下列指令:

git clone https://github.com/AI-Hypercomputer/maxtext
cd maxtext/dependencies/scripts
gcloud config set project PROJECT
bash ./docker_build_dependency_image.sh MODE=stable
gcloud auth configure-docker
bash ./docker_upload_runner.sh CLOUD_IMAGE_NAME=USER_runner # This script needs bash version >= 4.2 to execute.

這個指令會將 MaxText Kubernetes 映像檔推送至 gcr.io/$PROJECT/${USER}_runner。您可以使用這個 Docker 映像檔,透過 Pathways 後端在 TPU 上執行訓練。

使用 PathwaysJob API 執行批次工作負載

下列資訊清單會部署 Pathways 元件,並使用 PathwaysJob API 執行 MaxText 工作負載。工作負載會封裝在 main 容器中,並執行 train.py

將下列 YAML 複製到名為 pathways-job-batch-training.yaml 的檔案,並更新可編輯的值。

apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USER
spec:
  maxRestarts: MAX_RESTARTS
  workers:
  - type: TPU_MACHINE_TYPE
    topology: TOPOLOGY
    numSlices: WORKLOAD_NODEPOOL_COUNT
  pathwaysDir: gs://BUCKET_NAME
  controller:
    deploymentMode: default
    template:
      spec:
        containers:
        - name: main
          image: gcr.io/PROJECT/USER_runner
          command:
          - bash
          - -c
          - |
            python3 -m MaxText.train /deps/src/MaxText/configs/base.yml \
            base_output_directory=gs://BUCKET_NAME \
            run_name=RUN_NAME \
            per_device_batch_size=1 \
            enable_checkpointing=false \
            remat_policy=full \
            global_parameter_scale=1 \
            steps=20 \
            max_target_length=2048 \
            use_iota_embed=true \
            reuse_example_batch=1 \
            dataset_type=synthetic \
            attention=flash \
            gcs_metrics=True \
            enable_single_controller=True

更改下列內容:

  • USER:您的 Google Cloud 使用者 ID
  • MAX_RESTARTS:工作可重新啟動的次數上限
  • TPU_MACHINE_TYPETPU 機型
  • TOPOLOGY:TPU v4 或更新版本的拓撲。如要進一步瞭解 TPU 版本和支援的拓撲,請參閱「TPU 版本
  • WORKLOAD_NODEPOOL_COUNT:路徑工作負載使用的節點集區數量
  • BUCKET_NAME:用於儲存暫存檔案的 Cloud Storage bucket
  • PROJECT:您的 Google Cloud 專案 ID
  • RUN_NAME:使用者指派的名稱,用於識別工作流程執行作業

您可以按照下列方式部署 PathwaysJob YAML:

kubectl apply -f pathways-job-batch-training.yaml

如要查看先前指令建立的 PathwaysJob 執行個體,請使用:

kubectl get pathwaysjob

輸出內容應如下所示:

NAME             AGE
pathways-trial   9s

如要修改 PathwaysJob 執行個體的屬性,請刪除 PathwaysJob 執行個體、修改 YAML,然後套用該 YAML 來建立新的 PathwaysJob 執行個體。

如要追蹤工作負載的進度,請前往 JAX 容器的 Logs Explorer,然後在「Container Name」篩選器下方選擇 main

您應該會看到類似以下的記錄,表示訓練正在進行。 工作負載會在 30 個步驟後完成。

completed step: 1, seconds: 0.484, TFLOP/s/device: 87.349, Tokens/s/device: 2117.382, total_weights: 2945, loss: 10.888
completed step: 2, seconds: 0.407, TFLOP/s/device: 103.699, Tokens/s/device: 2513.735, total_weights: 3253, loss: 9.697
completed step: 3, seconds: 0.248, TFLOP/s/device: 170.300, Tokens/s/device: 4128.167, total_weights: 3154, loss: 9.641
completed step: 4, seconds: 0.216, TFLOP/s/device: 195.122, Tokens/s/device: 4729.880, total_weights: 3119, loss: 9.547
completed step: 5, seconds: 0.272, TFLOP/s/device: 155.298, Tokens/s/device: 3764.512, total_weights: 2837, loss: 10.179
completed step: 6, seconds: 0.472, TFLOP/s/device: 89.489, Tokens/s/device: 2169.266, total_weights: 3069, loss: 9.776

如要刪除 PathwaysJob 執行個體,可以使用下列指令:

kubectl delete -f pathways-job-batch-training.yaml

使用 XPK 執行批次工作負載

現在,您可以使用 XPK 提交預先建構的 Maxtext Docker 映像檔,指令與先前相同。

xpk workload create-pathways \
--workload=WORKLOAD \
--cluster=CLUSTER \
--num-slices=WORKLOAD_NODEPOOL_COUNT \
--tpu-type=TPU_TYPE \
--project=PROJECT \
--zone=ZONE \
--docker-image='gcr.io/PROJECT/USER_runner' \
--command="python3 -m MaxText.train /deps/src/MaxText/configs/base.yml base_output_directory=gs://BUCKET_NAME per_device_batch_size=1 enable_checkpointing=false remat_policy=full global_parameter_scale=1 steps=20 max_target_length=2048 use_iota_embed=true reuse_example_batch=1 dataset_type=synthetic attention=flash gcs_metrics=True enable_single_controller=True run_name=RUN_NAME-pathways-job"

更改下列內容:

  • WORKLOAD:用於識別工作負載的專屬名稱
  • CLUSTER:GKE 叢集名稱
  • WORKLOAD_NODEPOOL_COUNT:工作可重新啟動的次數上限
  • TPU_TYPE:TPU 類型會指定要建立的 Cloud TPU 版本和大小。如要進一步瞭解各個 TPU 版本支援的 TPU 類型,請參閱「TPU 版本」一節。
  • PROJECT:您 Google Cloud 專案的 ID
  • ZONE:您打算執行工作負載的可用區
  • USER:您的 Google Cloud 使用者 ID
  • RUN_NAME:使用者指派的名稱,用於識別工作流程執行作業

您會看到如下所示的輸出:

[XPK] Follow your Pathways workload and other resources here : https://console.cloud.google.com/logs/query;query=resource.type%3D"k8s_container"%0Aresource.labels.project_id%3D"<project-name>"%0Aresource.labels.location%3D"<your-zone>"%0Aresource.labels.cluster_name%3D"<your-cluster-name>"%0Aresource.labels.pod_name:"<your-pod-name>"%0Aseverity>%3DDEFAULT

使用上一個 XPK 指令輸出內容中的連結,追蹤工作負載的進度。如要篩選 JAX 容器的記錄,請選擇「Container Name」(容器名稱) 篩選器下方的 jax-tpu

completed step: 1, seconds: 0.484, TFLOP/s/device: 87.349, Tokens/s/device: 2117.382, total_weights: 2945, loss: 10.888
completed step: 2, seconds: 0.407, TFLOP/s/device: 103.699, Tokens/s/device: 2513.735, total_weights: 3253, loss: 9.697
completed step: 3, seconds: 0.248, TFLOP/s/device: 170.300, Tokens/s/device: 4128.167, total_weights: 3154, loss: 9.641
completed step: 4, seconds: 0.216, TFLOP/s/device: 195.122, Tokens/s/device: 4729.880, total_weights: 3119, loss: 9.547
completed step: 5, seconds: 0.272, TFLOP/s/device: 155.298, Tokens/s/device: 3764.512, total_weights: 2837, loss: 10.179
completed step: 6, seconds: 0.472, TFLOP/s/device: 89.489, Tokens/s/device: 2169.266, total_weights: 3069, loss: 9.776

工作負載會在指定步驟數後完成。如要提前終止,請使用下列指令:

xpk workload delete --workload=WORKLOAD --cluster=CLUSTER --project=PROJECT --zone=ZONE

後續步驟