Pathways でバッチ ワークロードを実行する

このドキュメントでは、バッチ ワークロードは、完了まで実行され、Pathways クラスタと同じ GKE クラスタ内、具体的には Pathways コントローラ コンポーネント(IFRT プロキシ サーバーと Pathways リソース マネージャー)とともにデプロイされる JAX ワークロードとして定義されます。JAX ワークロードが完了すると、Pathways クラスタ コンポーネントが終了します。このガイドでは、JAX トレーニング ワークロードを使用してこれを示します。

始める前に

インストールに必要なもの:

Maxtext を使用してトレーニング イメージをビルドする

MaxText は、Google が開発したオープンソースの大規模言語モデル(LLM)プロジェクトです。JAX で記述されており、高パフォーマンスでスケーラブルになるように設計されています。Google Cloud TPU と GPU で効率的に実行されます。

OSS GitHub リポジトリから最新バージョンの安定版 JAX を使用して MaxText Docker イメージをビルドするには、次のコマンドを実行します。

git clone https://github.com/AI-Hypercomputer/maxtext
cd maxtext/dependencies/scripts
gcloud config set project PROJECT
bash ./docker_build_dependency_image.sh MODE=stable
gcloud auth configure-docker
bash ./docker_upload_runner.sh CLOUD_IMAGE_NAME=USER_runner # This script needs bash version >= 4.2 to execute.

このコマンドは、MaxText Kubernetes イメージを gcr.io/$PROJECT/${USER}_runner に push します。この Docker イメージを使用して、Pathways バックエンドで TPU のトレーニングを実行できます。

PathwaysJob API を使用してバッチ ワークロードを実行する

次のマニフェストは、Pathways コンポーネントをデプロイし、PathwaysJob API を使用して MaxText ワークロードを実行します。ワークロードは main コンテナにカプセル化され、train.py を実行します。

次の YAML を pathways-job-batch-training.yaml という名前のファイルにコピーし、編集可能な値を更新します。

apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USER
spec:
  maxRestarts: MAX_RESTARTS
  workers:
  - type: TPU_MACHINE_TYPE
    topology: TOPOLOGY
    numSlices: WORKLOAD_NODEPOOL_COUNT
  pathwaysDir: gs://BUCKET_NAME
  controller:
    deploymentMode: default
    template:
      spec:
        containers:
        - name: main
          image: gcr.io/PROJECT/USER_runner
          command:
          - bash
          - -c
          - |
            python3 -m MaxText.train /deps/src/MaxText/configs/base.yml \
            base_output_directory=gs://BUCKET_NAME \
            run_name=RUN_NAME \
            per_device_batch_size=1 \
            enable_checkpointing=false \
            remat_policy=full \
            global_parameter_scale=1 \
            steps=20 \
            max_target_length=2048 \
            use_iota_embed=true \
            reuse_example_batch=1 \
            dataset_type=synthetic \
            attention=flash \
            gcs_metrics=True \
            enable_single_controller=True

次のように置き換えます。

  • USER : 実際の Google Cloud ユーザー ID
  • MAX_RESTARTS : Job を再起動できる最大回数
  • TPU_MACHINE_TYPE : TPU マシンタイプ
  • TOPOLOGY : TPU v4 以降のトポロジ。TPU のバージョンとサポートされているトポロジの詳細については、TPU のバージョンをご覧ください。
  • WORKLOAD_NODEPOOL_COUNT : Pathways ワークロードで使用されるノードプールの数
  • BUCKET_NAME : 一時ファイルを保存する Cloud Storage バケット
  • PROJECT : 実際の Google Cloud プロジェクト ID
  • RUN_NAME : ワークフロー実行を識別するためにユーザーが割り当てた名前

PathwaysJob YAML は次のようにデプロイできます。

kubectl apply -f pathways-job-batch-training.yaml

前のコマンドで作成した PathwaysJob インスタンスを表示するには、次のコマンドを使用します。

kubectl get pathwaysjob

出力は次のようになります。

NAME             AGE
pathways-trial   9s

PathwaysJob インスタンスの属性を変更するには、PathwaysJob インスタンスを削除し、YAML を変更して適用し、新しい PathwaysJob インスタンスを作成します。

ワークロードの進行状況を確認するには、[コンテナ名] フィルタで main を選択して、JAX コンテナのログ エクスプローラに移動します。

次のようなログが表示され、トレーニングが進行中であることがわかります。ワークロードは 30 ステップ後に完了します。

completed step: 1, seconds: 0.484, TFLOP/s/device: 87.349, Tokens/s/device: 2117.382, total_weights: 2945, loss: 10.888
completed step: 2, seconds: 0.407, TFLOP/s/device: 103.699, Tokens/s/device: 2513.735, total_weights: 3253, loss: 9.697
completed step: 3, seconds: 0.248, TFLOP/s/device: 170.300, Tokens/s/device: 4128.167, total_weights: 3154, loss: 9.641
completed step: 4, seconds: 0.216, TFLOP/s/device: 195.122, Tokens/s/device: 4729.880, total_weights: 3119, loss: 9.547
completed step: 5, seconds: 0.272, TFLOP/s/device: 155.298, Tokens/s/device: 3764.512, total_weights: 2837, loss: 10.179
completed step: 6, seconds: 0.472, TFLOP/s/device: 89.489, Tokens/s/device: 2169.266, total_weights: 3069, loss: 9.776

PathwaysJob インスタンスを削除するには、次のコマンドを使用します。

kubectl delete -f pathways-job-batch-training.yaml

XPK を使用してバッチ ワークロードを実行する

これで、以前に使用したのと同じコマンドを使用して、XPK でビルド済みの Maxtext Docker イメージを送信できます。

xpk workload create-pathways \
--workload=WORKLOAD \
--cluster=CLUSTER \
--num-slices=WORKLOAD_NODEPOOL_COUNT \
--tpu-type=TPU_TYPE \
--project=PROJECT \
--zone=ZONE \
--docker-image='gcr.io/PROJECT/USER_runner' \
--command="python3 -m MaxText.train /deps/src/MaxText/configs/base.yml base_output_directory=gs://BUCKET_NAME per_device_batch_size=1 enable_checkpointing=false remat_policy=full global_parameter_scale=1 steps=20 max_target_length=2048 use_iota_embed=true reuse_example_batch=1 dataset_type=synthetic attention=flash gcs_metrics=True enable_single_controller=True run_name=RUN_NAME-pathways-job"

次のように置き換えます。

  • WORKLOAD: ワークロードを識別する一意の名前
  • CLUSTER: GKE クラスタの名前
  • WORKLOAD_NODEPOOL_COUNT : ジョブを再起動できる最大回数
  • TPU_TYPE: TPU タイプは、作成する Cloud TPU のバージョンとサイズを指定します。TPU の各バージョンでサポートされている TPU タイプの詳細については、TPU のバージョンをご覧ください。
  • PROJECT : 実際の Google Cloud プロジェクト ID
  • ZONE: ワークロードを実行する予定のゾーン
  • USER : 実際の Google Cloud ユーザー ID
  • RUN_NAME : ワークフロー実行を識別するためにユーザーが割り当てた名前

次のような出力が表示されます。

[XPK] Follow your Pathways workload and other resources here : https://console.cloud.google.com/logs/query;query=resource.type%3D"k8s_container"%0Aresource.labels.project_id%3D"<project-name>"%0Aresource.labels.location%3D"<your-zone>"%0Aresource.labels.cluster_name%3D"<your-cluster-name>"%0Aresource.labels.pod_name:"<your-pod-name>"%0Aseverity>%3DDEFAULT

前の XPK コマンドの出力にあるリンクを使用して、ワークロードの進行状況を確認します。コンテナ名フィルタで jax-tpu を選択すると、JAX コンテナのログをフィルタリングできます。

completed step: 1, seconds: 0.484, TFLOP/s/device: 87.349, Tokens/s/device: 2117.382, total_weights: 2945, loss: 10.888
completed step: 2, seconds: 0.407, TFLOP/s/device: 103.699, Tokens/s/device: 2513.735, total_weights: 3253, loss: 9.697
completed step: 3, seconds: 0.248, TFLOP/s/device: 170.300, Tokens/s/device: 4128.167, total_weights: 3154, loss: 9.641
completed step: 4, seconds: 0.216, TFLOP/s/device: 195.122, Tokens/s/device: 4729.880, total_weights: 3119, loss: 9.547
completed step: 5, seconds: 0.272, TFLOP/s/device: 155.298, Tokens/s/device: 3764.512, total_weights: 2837, loss: 10.179
completed step: 6, seconds: 0.472, TFLOP/s/device: 89.489, Tokens/s/device: 2169.266, total_weights: 3069, loss: 9.776

ワークロードは、指定したステップ数で完了します。早期に終了する場合は、次のコマンドを使用します。

xpk workload delete --workload=WORKLOAD --cluster=CLUSTER --project=PROJECT --zone=ZONE

次のステップ