GKE で Ray Train を使用した TPU でのマルチスライスとエラスティック トレーニング

このチュートリアルでは、MaxTextRay Train、Multislice Trillium TPU を使用して、Google Kubernetes Engine(GKE)で Llama 3 70B などの大規模言語モデル(LLM)をトレーニングする方法について説明します。このチュートリアルでは、必要なセカンダリ データセンター ネットワーキングの構成から、32 個の物理 TPU チップに分散されたトレーニング ワークロードを送信して正常に実行するまで、エンドツーエンドの完全なチュートリアルを提供します。

このチュートリアルは、分散マルチホスト TPU スライスで 700 億個のパラメータ モデルをトレーニングする際のメモリとネットワークの課題を克服する方法を学習するプラットフォーム管理者、オペレーター、AI スペシャリストを対象としています。

背景

GKE、KubeRay、MaxText、TPU を組み合わせることで、大規模なモデル トレーニングのための強力でスケーラブルなプラットフォームが実現します。このセクションでは、このガイドで使用されている重要なテクノロジーについて説明します。

JAX

JAX は、アクセラレータ指向の配列計算とプログラム変換のための Python ライブラリです。XLA コンパイラを使用して、アクセラレータで効率的にスケーリングする高度に最適化されたコードを作成します。

MaxText

MaxText は、スケーラビリティとカスタマイズ性を重視して設計された、高パフォーマンスのオープンソース LLM フレームワークです。MaxText は JAX 上に構築されており、Cloud TPU で効率的に実行できるように最適化されています。

TPU

Tensor Processing Unit(TPU)は、機械学習ワークロードを最適化するために Google が作成したカスタム設計のアクセラレータです。汎用 CPU や並列処理 GPU とは異なり、TPU はディープ ラーニングの基盤となる大規模な行列とテンソルの計算に特化しているため、この特定のタスクを効率的に実行できます。TPU の主な利点は、パフォーマンス拡張です。

このチュートリアルでは、マルチスライス デプロイ パターンで第 6 世代 TPU である TPU Trillium を使用します。Cloud TPU マルチスライスは、2 つ以上の Cloud TPU スライスがデータセンター ネットワーク(DCN)上で通信する場所です。マルチスライスは、フルスタックで費用対効果に優れた大規模なトレーニングを可能にします。最大数万の TPU チップまでほぼ線形にスケールアップできます。マルチスライスの詳細については、Cloud TPU マルチスライスの概要をご覧ください。

KubeRay

KubeRay は、Kubernetes で Ray アプリケーションをデプロイ、管理、モニタリングするための統一された方法を提供する Kubernetes オペレーターです。KubeRay オペレーターは、Ray on GKE アドオンを介してインストールおよび管理されます。これは、GKE 上の Ray クラスタをデプロイして管理するおすすめの方法です。

GKE 動的リソース割り当てネットワーク(DRANET)

GKE DRANET(動的リソース割り当てネットワーク)は、高性能ネットワーク デバイスを Pod に動的に接続し、標準の Kubernetes ネットワーキングをバイパスして、DCN で高性能を実現する機能です。

目標

このチュートリアルでは、次の方法を説明します。

  1. 2 つのマルチホスト TPU ノードプールを使用して GKE クラスタを設定します。
  2. クロススライス TPU 通信用のセカンダリ DCN を構成します。
  3. 分散トレーニング環境を管理するように KubeRay を構成します。
  4. ネットワーク アタッチメントに動的リソース割り当て(DRA)を使用して、RayCluster カスタム リソースをデプロイします。
  5. Ray Train の JaxTrainer を利用して、TPU スライス全体で MaxText トレーニング ループをオーケストレートする Python トレーニング スクリプトを作成します。
  6. ベースラインの Llama 3 8B トレーニング ジョブを実行します。
  7. DCN 上で 2D シャーディング(テンソル並列処理と FSDP)を利用して、Llama 3 70B までスケールアップします。

始める前に

  • Google Cloud アカウントにログインします。 Google Cloudを初めて使用する場合は、 アカウントを作成して、実際のシナリオでの Google プロダクトのパフォーマンスを評価してください。新規のお客様には、ワークロードの実行、テスト、デプロイができる無料クレジット $300 分を差し上げます。
  • Google Cloud CLI をインストールします。

  • 外部 ID プロバイダ(IdP)を使用している場合は、まず連携 ID を使用して gcloud CLI にログインする必要があります。

  • gcloud CLI を初期化するには、次のコマンドを実行します。

    gcloud init
  • Google Cloud プロジェクトを作成または選択します

    プロジェクトの選択または作成に必要なロール

    • プロジェクトを選択する: プロジェクトの選択に特定の IAM ロールは必要ありません。ロールが付与されているプロジェクトであれば、どのプロジェクトでも選択できます。
    • プロジェクトを作成する: プロジェクトを作成するには、resourcemanager.projects.create 権限を含むプロジェクト作成者ロール(roles/resourcemanager.projectCreator)が必要です。詳しくは、ロールを付与する方法をご覧ください。
    • Google Cloud プロジェクトを作成します。

      gcloud projects create PROJECT_ID

      PROJECT_ID は、作成する Google Cloud プロジェクトの名前に置き換えます。

    • 作成した Google Cloud プロジェクトを選択します。

      gcloud config set project PROJECT_ID

      PROJECT_ID は、 Google Cloud プロジェクトの名前に置き換えます。

  • Google Cloud プロジェクトに対して課金が有効になっていることを確認します

  • 必要な API を有効にします。

    API を有効にするために必要なロール

    API を有効にするには、serviceusage.services.enable 権限を含む Service Usage 管理者 IAM ロール(roles/serviceusage.serviceUsageAdmin)が必要です。詳しくは、ロールを付与する方法をご覧ください。

    gcloud services enable container.googleapis.com cloudbuild.googleapis.com
  • Google Cloud CLI をインストールします。

  • 外部 ID プロバイダ(IdP)を使用している場合は、まず連携 ID を使用して gcloud CLI にログインする必要があります。

  • gcloud CLI を初期化するには、次のコマンドを実行します。

    gcloud init
  • Google Cloud プロジェクトを作成または選択します

    プロジェクトの選択または作成に必要なロール

    • プロジェクトを選択する: プロジェクトの選択に特定の IAM ロールは必要ありません。ロールが付与されているプロジェクトであれば、どのプロジェクトでも選択できます。
    • プロジェクトを作成する: プロジェクトを作成するには、resourcemanager.projects.create 権限を含むプロジェクト作成者ロール(roles/resourcemanager.projectCreator)が必要です。詳しくは、ロールを付与する方法をご覧ください。
    • Google Cloud プロジェクトを作成します。

      gcloud projects create PROJECT_ID

      PROJECT_ID は、作成する Google Cloud プロジェクトの名前に置き換えます。

    • 作成した Google Cloud プロジェクトを選択します。

      gcloud config set project PROJECT_ID

      PROJECT_ID は、 Google Cloud プロジェクトの名前に置き換えます。

  • Google Cloud プロジェクトに対して課金が有効になっていることを確認します

  • 必要な API を有効にします。

    API を有効にするために必要なロール

    API を有効にするには、serviceusage.services.enable 権限を含む Service Usage 管理者 IAM ロール(roles/serviceusage.serviceUsageAdmin)が必要です。詳しくは、ロールを付与する方法をご覧ください。

    gcloud services enable container.googleapis.com cloudbuild.googleapis.com
  • ユーザー アカウントにロールを付与します。次の IAM ロールごとに次のコマンドを 1 回実行します。 roles/container.admin, roles/iam.serviceAccountAdmin, roles/cloudbuild.builds.editor

    gcloud projects add-iam-policy-binding PROJECT_ID --member="user:USER_IDENTIFIER" --role=ROLE

    次のように置き換えます。

    • PROJECT_ID: プロジェクト ID。
    • USER_IDENTIFIER: ユーザー アカウントの識別子。例: myemail@example.com
    • ROLE: ユーザー アカウントに付与する IAM ロール。
  • このチュートリアルでは TPU Trillium(v6e)を利用するため、利用可能なリージョンまたはゾーンを選択します。詳細については、Cloud TPU の割り当てをご覧ください。

環境を準備する

このチュートリアルでは、Cloud Shell を使用します。Cloud Shell には、このチュートリアルで使用する gcloudhelmkubectl コマンドライン ツールがプリインストールされています。

  1. Google Cloud コンソールに移動します。

  2. Google Cloud コンソール ウィンドウの上部にある [Cloud Shell をアクティブにする] Shell をアクティブにするボタン ボタンをクリックします。

    Google Cloud コンソールの新しいフレーム内で Cloud Shell セッションが開き、コマンドライン プロンプトが表示されます。

  3. ターミナルで、kubernetes-engine-samples リポジトリのクローンを作成します。

    git clone https://github.com/GoogleCloudPlatform/kubernetes-engine-samples.git
    
  4. サンプル ファイルが含まれているディレクトリに移動します。

    cd kubernetes-engine-samples/ai-ml/gke-ray/raytrain/maxtext
    
  5. Python 仮想環境を作成してアクティブにします。

    python3 -m venv ray-env
    source ray-env/bin/activate
    
  6. Ray CLI をインストールします。

    pip install "ray[default]==2.55.0"
    
  7. 次の環境変数を設定します。

    export PROJECT_ID=$(gcloud config get project)
    export PROJECT_NUMBER=$(gcloud projects describe ${PROJECT_ID} --format="value(projectNumber)")
    export GS_BUCKET=GS_BUCKET
    export KSA_NAME=KSA_NAME
    export NAMESPACE=default
    export CLUSTER_NAME=CLUSTER_NAME
    export REGION=REGION
    export ZONE=ZONE
    export CLUSTER_VERSION=1.35.2-gke.1842000
    

    次のように置き換えます。

    • GS_BUCKET: Cloud Storage バケットの名前。
    • KSA_NAME: Kubernetes サービス アカウントの名前。
    • CLUSTER_NAME: 新しいクラスタの名前。
    • REGION: TPU Trillium の容量が使用可能なリージョン。
    • ZONE: TPU Trillium の容量が使用可能なゾーン。詳細については、GKE での TPU の可用性をご覧ください。

Cloud TPU マルチスライスのクラスタ ネットワーキングを構成する

マルチホスト TPU スライス内では、TPU デバイスは高速チップ間相互接続を介して通信します。ただし、マルチスライス ジョブを実行する場合は、TPU スライスが DCN を介して相互に通信する必要があります。標準の Kubernetes Pod ネットワークでは、このトラフィックがボトルネックになる可能性があります。ct6e-standard-4t マシンタイプは、複数の物理ネットワーク インターフェース カード(NIC)を基盤としています。最高のパフォーマンスを実現するには、2 つの追加の VPC ネットワークを作成し、GKE DRANET を使用して Ray Pod に直接接続します。

  1. 大きな最大トレーニング単位(MTU)を使用して、2 つの追加の VPC ネットワークを作成します。

    gcloud compute networks create ${CLUSTER_NAME}-net-1 \
      --subnet-mode=custom \
      --mtu=8896
    
    gcloud compute networks create ${CLUSTER_NAME}-net-2 \
      --subnet-mode=custom \
      --mtu=8896
    
  2. 専用サブネットを作成します。

    gcloud compute networks subnets create tpu-subnet-1 \
      --network=${CLUSTER_NAME}-net-1 \
      --region=${REGION} \
      --range=10.50.0.0/16
    
    gcloud compute networks subnets create tpu-subnet-2 \
      --network=${CLUSTER_NAME}-net-2 \
      --region=${REGION} \
      --range=10.60.0.0/16
    

GKE クラスタを作成する

GKE Autopilot クラスタまたは GKE Standard クラスタの TPU で KubeRay を構成できます。フルマネージドの Kubernetes エクスペリエンスを実現するには、Autopilot クラスタを使用することをおすすめします。ワークロードに最適な GKE の運用モードを選択するには、GKE の運用モードについてをご覧ください。

GKE マネージド DRANET を使用するには、クラスタで Autopilot モードの場合はバージョン 1.35.2-gke.1842000 以降、標準モードの場合は 1.34.1-gke.1829001 以降を使用する必要があります。このチュートリアルでは、バージョン 1.35.2-gke.1842000 を使用します。

Autopilot

  1. Cloud Shell で、次のコマンドを実行します。

    gcloud container clusters create-auto $CLUSTER_NAME \
        --enable-ray-operator \
        --machine-type=n1-standard-16 \
        --location=$REGION \
        --cluster-version=${CLUSTER_VERSION}
    
  2. クラスタと通信するには、kubectl を構成します。

    gcloud container clusters get-credentials CLUSTER_NAME \
        --location=$REGION
    

Standard

  1. Cloud Shell で、次のコマンドを実行して、Ray オペレータ アドオンを有効にする Standard クラスタを作成します。

    gcloud container clusters create $CLUSTER_NAME \
        --addons=RayOperator,GcsFuseCsiDriver \
        --machine-type=n1-standard-16 \
        --enable-dataplane-v2 \
        --workload-pool=$PROJECT_ID.svc.id.goog \
        --location=$ZONE \
        --cluster-version=${CLUSTER_VERSION}
    

    このコマンドは GcsFuseCsiDriver も有効にします。これにより、Pod は Cloud Storage バケットをローカル ファイル システムとしてマウントできます。クラスタの作成には数分かかることもあります。

  2. クラスタと通信するには、kubectl を構成します。

    gcloud container clusters get-credentials CLUSTER_NAME \
        --location=$ZONE
    
  3. GKE DRANET を有効にして、最初のマルチホスト TPU スライス ノードプールを作成します。

    gcloud container node-pools create v6e-16-0 \
        --location=$ZONE \
        --cluster=$CLUSTER_NAME \
        --machine-type=ct6e-standard-4t \
        --threads-per-core=1 \
        --tpu-topology=4x4 \
        --num-nodes=4 \
        --additional-node-network=network=${CLUSTER_NAME}-net-1,subnetwork=tpu-subnet-1 \
        --additional-node-network=network=${CLUSTER_NAME}-net-2,subnetwork=tpu-subnet-2 \
        --node-labels=cloud.google.com/gke-networking-dra-driver=true \
        --enable-gvnic \
        --scopes=https://www.googleapis.com/auth/cloud-platform
    
  4. 2 つ目の TPU スライス ノードプールを作成します。

    gcloud container node-pools create v6e-16-1 \
        --location=$ZONE \
        --cluster=$CLUSTER_NAME \
        --machine-type=ct6e-standard-4t \
        --threads-per-core=1 \
        --tpu-topology=4x4 \
        --num-nodes=4 \
        --additional-node-network=network=${CLUSTER_NAME}-net-1,subnetwork=tpu-subnet-1 \
        --additional-node-network=network=${CLUSTER_NAME}-net-2,subnetwork=tpu-subnet-2 \
        --node-labels=cloud.google.com/gke-networking-dra-driver=true \
        --enable-gvnic \
        --scopes=https://www.googleapis.com/auth/cloud-platform
    

GKE は、4 つの TPU Trillium(v6e)VM で構成されるノードプールをプロビジョニングします。これらは、4x4 トポロジを持つマルチホスト TPU スライスとして構成されます。このノードプールは、分散トレーニング ワークロードの準備ができています。

Ray オペレーターが有効になっている GKE クラスタは、クラスタに KubeRay と KubeRay TPU Webhook を自動的にインストールします。

Cloud Storage バケットとサービス アカウントを構成する

  1. マルチホスト TPU ノード間で共有されるチェックポイント用の Cloud Storage バケットを作成します。

    gsutil mb -p ${PROJECT_ID} -c STANDARD -l ${REGION} gs://${GS_BUCKET}
    
  2. Cloud Storage バケットへのアクセスを有効にするには、Kubernetes サービス アカウントを作成します。

    kubectl create serviceaccount ${KSA_NAME} --namespace ${NAMESPACE}
    
  3. Cloud Storage バケットへのアクセスを有効にするには、必要な IAM ポリシー バインディングをサービス アカウントに追加します。

    gcloud storage buckets add-iam-policy-binding gs://${GS_BUCKET} \
        --member "principal://iam.googleapis.com/projects/${PROJECT_NUMBER}/locations/global/workloadIdentityPools/${PROJECT_ID}.svc.id.goog/subject/ns/${NAMESPACE}/sa/${KSA_NAME}" \
        --role "roles/storage.objectUser"
    

トレーニング スクリプトを作成する

maxtext_multi_slice_trainer.py スクリプトは、Ray Train の JaxTrainer を使用して、2 つの TPU スライスで分散 MaxText トレーニング ジョブを実行します。このスクリプトは、8 つのマルチホスト TPU ワーカーのトレーニング環境を構成し、各ワーカーノードで MaxText トレーニング ジョブを実行します。train_loop_per_worker 関数は MaxText のメイン エントリ ポイントをラップし、Ray の分散スケジューラを使用してマルチホスト TPU スライスで MaxText トレーナーを実行します。

import os
from absl import app
import logging
from typing import Sequence
import ray
from ray.train.v2.api.config import ScalingConfig, RunConfig
from ray.train.v2.jax import JaxTrainer

def train_loop_per_worker(config):
    import maxtext
    from maxtext.trainers.pre_train.train import main as maxtext_main

    argv = config["argv"]
    maxtext_main(argv)

def main(argv: Sequence[str]):
    # Convert the config file path to an absolute path
    argv = list(argv)
    if len(argv) > 1:
        argv[1] = os.path.abspath(argv[1])

    trainer = JaxTrainer(
        train_loop_per_worker=train_loop_per_worker,
        train_loop_config={"argv": argv},
        scaling_config=ScalingConfig(
            use_tpu=True,
            num_workers=8,
            topology="4x4",
            accelerator_type="TPU-V6E",
            resources_per_worker={"TPU": 4},
            placement_strategy="SPREAD",
        ),
        run_config=RunConfig(
            name="maxtext_jaxtrainer",
            worker_runtime_env={
                "uv": {
                    # maxtext requires some additional deps
                    "packages": ["maxtext[tpu]==0.2.1"],
                    "uv_pip_install_options": ["--resolution=lowest"]
                },
            },
        ),
    )
    result = trainer.fit()
    logging.info("Training complete!")
    ray.shutdown()

if __name__ == "__main__":
    app.run(main)

上記のスクリプトは、8 つのワーカーと 4x4 のトポロジをリクエストする JaxTrainer インスタンスを定義します。内部的には、Ray は 2 つの TPU スライスに SlicePlacementGroup をプロビジョニングし、Ray Train ワーカーが両方のスライスでアトミックに実行されるようにします(ホストごとに 1 つのワーカー)。

モデルのトレーニング

  1. 現在のディレクトリの ray-cluster.tpu-multi-slice.yaml マニフェストは、RayCluster カスタム リソースを定義します。このマニフェストには、GKE DRANET と Multislice のネットワーク デバイスをプロビジョニングする DRANET ResourceClaimTemplate が含まれています。

    apiVersion: resource.k8s.io/v1
    kind: ResourceClaimTemplate
    metadata:
      name: two-netdev
    spec:
      spec:
        devices:
          requests:
          - name: req-netdev
            exactly:
              deviceClassName: netdev.google.com
              allocationMode: ExactCount
              count: 2
    ---
    apiVersion: ray.io/v1
    kind: RayCluster
    metadata:
      name: maxtext-tpu-cluster
    spec:
      headGroupSpec:
        rayStartParams: {}
        template:
          metadata:
            annotations:
              gke-gcsfuse/volumes: "true"
              gke-gcsfuse/cpu-limit: "0"
              gke-gcsfuse/memory-limit: "0"
              gke-gcsfuse/ephemeral-storage-limit: "0"
          spec:
            serviceAccountName: ${KSA_NAME}
            containers:
              - name: ray-head
                image: rayproject/ray:nightly-py312-tpu
                imagePullPolicy: Always
                ports:
                - containerPort: 6379
                  name: gcs-server
                - containerPort: 8265
                  name: dashboard
                - containerPort: 10001
                  name: client
                resources:
                  limits:
                    memory: "16Gi"
                  requests:
                    cpu: "8"
                    memory: "16Gi"
                volumeMounts:
                - name: gcs-fuse-csi-ephemeral
                  mountPath: /data
                - name: dshm
                  mountPath: /dev/shm
            volumes:
            - name: dshm
              emptyDir:
                medium: Memory
            - name: gcs-fuse-csi-ephemeral
              csi:
                driver: gcsfuse.csi.storage.gke.io
                volumeAttributes:
                  bucketName: ${GS_BUCKET}
                  mountOptions: "implicit-dirs,uid=1000,gid=1000,dir-mode=775,file-mode=664,file-cache:max-size-mb:-1"
            nodeSelector:
              iam.gke.io/gke-metadata-server-enabled: "true"
      workerGroupSpecs:
        - replicas: 2
          numOfHosts: 4
          groupName: tpu-group
          rayStartParams: 
            metrics-export-port: "8082"
          template:
            metadata:
              annotations:
                gke-gcsfuse/volumes: "true"
                gke-gcsfuse/cpu-limit: "0"
                gke-gcsfuse/memory-limit: "0"
                gke-gcsfuse/ephemeral-storage-limit: "0"
            spec:
              serviceAccountName: ${KSA_NAME}
              resourceClaims:
              - name: netdev
                resourceClaimTemplateName: two-netdev
              containers:
                - name: ray-worker
                  image: rayproject/ray:nightly-py312-tpu
                  imagePullPolicy: Always
                  resources:
                    claims:
                    - name: netdev
                    limits:
                      memory: 200G
                      google.com/tpu: "4"
                    requests:
                      cpu: "8"
                      memory: 200G
                      google.com/tpu: "4"
                  env:
                    - name: MEGASCALE_NUM_SLICES
                      value: "2"
                    - name: MEGASCALE_PORT
                      value: "9915"
                    - name: JAX_PLATFORMS
                      value: tpu,cpu
                    - name: ENABLE_PJRT_COMPATIBILITY
                      value: "true"
                    - name: LIBTPU_INIT_ARGS
                      value: "--xla_tpu_scoped_vmem_limit_kib=122880 --xla_tpu_use_minor_sharding_for_major_trivial_input=true --xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 --xla_tpu_assign_all_reduce_scatter_layout --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true --megascale_grpc_interface_prefixes=eth1,eth2,lo"
                  securityContext:
                    privileged: true
                  volumeMounts:
                  - name: gcs-fuse-csi-ephemeral
                    mountPath: /data
                  - name: dshm
                    mountPath: /dev/shm
              volumes:
              - name: dshm
                emptyDir:
                  medium: Memory
              - name: gcs-fuse-csi-ephemeral
                csi:
                  driver: gcsfuse.csi.storage.gke.io
                  volumeAttributes:
                    bucketName: ${GS_BUCKET}
                    mountOptions: "implicit-dirs,uid=1000,gid=1000,dir-mode=775,file-mode=664,file-cache:max-size-mb:-1"
              nodeSelector:
                iam.gke.io/gke-metadata-server-enabled: "true"
                cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                cloud.google.com/gke-tpu-topology: 4x4
    

    上記の RayCluster 仕様では、レプリカごとに 8 つのワーカー(numOfHosts: 4)を含む TPU ワーカー グループを 2 つのレプリカで作成します。各ワーカーは 4 つの TPU チップ(google.com/tpu: "4")をリクエストします。ワーカーはそれぞれ、同じコロケーションされたマルチホスト スライスの一部である TPU Trillium ノード(tpu-v6e-slice)でスケジュールされます。KubeRay は、スライス内の 4 つのワーカーすべてをアトミックにスケーリングします。必要な JAX 環境変数とスケジューリング用の Pod アフィニティは、変更用 Webhook を介して GKE によってブートストラップされます。

  2. RayCluster を作成するには、次のマニフェストを適用します。

    envsubst < ray-cluster.tpu-multi-slice.yaml | kubectl apply -f -
    
  3. クラスタが使用できるようになり、実行中であることを確認します。

    kubectl get rayclusters maxtext-tpu-cluster
    

    出力例を以下に示します。

    NAME                  DESIRED WORKERS   AVAILABLE WORKERS   CPUS   MEMORY         GPUS   STATUS   AGE
    maxtext-tpu-cluster   8                 8                   72     1579277216Ki   0      ready    2m11s
    
  4. Ray ヘッドサービスを介して Ray ダッシュボードにアクセスするには、ポート転送セッションを確立します。

    kubectl port-forward svc/maxtext-tpu-cluster-head-svc 8265:8265 2>&1 >/dev/null &
    
  5. ローカル環境から RayCluster にアクセスできることを確認します。

    ray list nodes --address http://localhost:8265
    

    出力例を以下に示します。

    ray list nodes --address http://localhost:8265
    2026-04-21 10:20:20,080 - INFO - Note: NumExpr detected 64 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
    2026-04-21 10:20:20,080 - INFO - NumExpr defaulting to 8 threads.
    
    ======== List: 2026-04-21 10:20:20.945431 ========
    Stats:
    ------------------------------
    Total: 9
    
    Table:
    ------------------------------
        NODE_ID                                                   NODE_IP     IS_HEAD_NODE    STATE    STATE_MESSAGE    NODE_NAME    RESOURCES_TOTAL                   LABELS
    0  4f0e4d742de5375047c7688f4d2bc64a42d1e5c77c2d8344b3b375a1  10.68.9.5   False           ALIVE                     10.68.9.5    CPU: 8.0                          ray.io/accelerator-type: TPU-V6E
                                                                                                                                    TPU: 4.0                          ray.io/node-group: tpu-group
                                                                                                                                    accelerator_type:TPU-V6E: 1.0     ray.io/node-id: 4f0e4d742...
                                                                                                                                    memory: 186.265 GiB               ray.io/tpu-pod-type: v6e-16
                                                                                                                                    node:10.68.9.5: 1.0               ray.io/tpu-slice-name: tpu-group-0
                                                                                                                                    object_store_memory: 186.265 GiB  ray.io/tpu-topology: 4x4
                                                                                                                                    tpu-group-0: 1.0                  ray.io/tpu-worker-id: '1'
    ...
    6  ce7056807b95831ce107ba1951dac34b80635e6fdbb312e7f9649938  10.68.2.9   True            ALIVE                     10.68.2.9    CPU: 8.0                          ray.io/node-group: headgroup
                                                                                                                                    memory: 16.000 GiB                ray.io/node-id: ce7056807...
                                                                                                                                    node:10.68.2.9: 1.0
                                                                                                                                    node:__internal_head__: 1.0
                                                                                                                                    object_store_memory: 4.765 GiB
    ...
    
  6. 基本の MaxText 構成ファイルをダウンロードします。このファイルは、モデルのデフォルトのハイパーパラメータを設定するためにトレーニング スクリプトで必要です。

    curl -O https://raw.githubusercontent.com/google/maxtext/maxtext-v0.2.1/src/maxtext/configs/base.yml
    
  7. JaxTrainer スクリプトを RayCluster に送信し、RayJob が正常に完了することを確認します。

Llama 3 8B

ray job submit \
  --address http://localhost:8265 \
  --working-dir . \
  --runtime-env-json '{"excludes": ["ray-env", ".git"]}' \
  -- python maxtext_multi_slice_trainer.py \
      base.yml \
      base_output_directory=/data/ \
      dataset_type=synthetic \
      per_device_batch_size=4 \
      max_target_length=4096 \
      model_name=llama3-8b \
      steps=100 \
      ici_fsdp_parallelism=4 \
      ici_tensor_parallelism=4 \
      run_name=rayjob-multi-slice

Llama 3 70B

ray job submit \
  --address http://localhost:8265 \
  --working-dir . \
  --runtime-env-json '{"excludes": ["ray-env", ".git"]}' \
  -- python maxtext_multi_slice_trainer.py \
      base.yml \
      base_output_directory=/data/ \
      dataset_type=synthetic \
      per_device_batch_size=2 \
      max_target_length=4096 \
      model_name=llama3-70b \
      steps=100 \
      ici_tensor_parallelism=4 \
      ici_fsdp_parallelism=4 \
      dcn_fsdp_parallelism=2 \
      dcn_data_parallelism=1 \
      remat_policy=full \
      run_name=rayjob-multi-slice-70b-fsdp

上記のコマンドは、JaxTrainer Ray コードを呼び出す Python スクリプトを RayCluster に送信します。ray job submit コマンドには、モデル構成に渡す MaxText に固有の引数が含まれています。

ターミナルに、Llama 3 70B ジョブの出力が表示されます。

[process=5][thread=save_finalize][step=99] CheckpointManager Save Finalize is done on all hosts. [repeated 7x across cluster]
(RayTrainWorker pid=130520, ip=10.60.7.7) [process=5][thread=TrainingThread(train_fn_with_final_checkpoint_flush)][step=99][wait_until_finished] Done waiting for Save Finalize thread (save_finalize) running at step=99. [repeated 7x across cluster]
(RayTrainWorker pid=130520, ip=10.60.7.7) [process=5][thread=TrainingThread(train_fn_with_final_checkpoint_flush)][wait_until_finished] No Save Finalize thread to wait for. Returning. [repeated 6x across cluster]
(RayTrainWorker pid=130520, ip=10.60.7.7) completed step: 99, seconds: 0.693, TFLOP/s/device: 83.171, Tokens/s/device: 11819.175, total_weights: 262144, loss: 0.334 [repeated 6x across cluster]

------------------------------------------
Job 'raysubmit_XwUdZMrhsYRKvjqs' succeeded
------------------------------------------

Spot VM でマルチスライス エラスティック トレーニングを実行する

TPU などの需要の高いアクセラレータを使用する場合は、Spot VM を利用することでコストを大幅に削減できます。ただし、Spot VM は予期せずプリエンプトされることがあります。

Ray Train はエラスティック トレーニングをサポートしています。これにより、ジョブは参加している TPU スライスの数を動的にスケーリングできます。スライスがプリエンプトされると、Ray はトレーニング ループを一時停止し、残りのワーカーが再編成されるのを待ってから、最新の MaxText チェックポイントから復元し、フットプリントを小さくしてトレーニングを再開します。

エラスティック トレーニングを有効にするには、ScalingConfignum_workers パラメータを静的整数から (minimum_workers, maximum_workers) を表すタプルに変更します。また、RunConfigFailureConfig(max_failures=3) を追加します。これにより、ワーカーがプリエンプトされたときにジョブ全体を失敗させるのではなく、トレーニング ループを最大 3 回再試行するように Ray Train に指示します。

Ray Train スクリプトを更新する

  1. 現在のディレクトリにある maxtext_elastic_trainer.py スクリプトにより、エラスティック トレーニングが有効になります。num_workers=(4,8) が設定されていることに注意してください。これは、16 チップ スライス(4 つのワーカー)が 1 つ以上使用可能な場合は Ray に続行するよう指示し、可能であれば 2 つのスライス(8 つのワーカー)にスケールアップするよう指示します。これには、エラスティック トレーニングを有効にし、再試行回数を定義し、ジョブがプリエンプションを回避できるようにする FailureConfig が含まれています。

    import os
    from absl import app
    import logging
    from typing import Sequence
    import ray
    from ray.train.v2.api.config import ScalingConfig, RunConfig, FailureConfig
    from ray.train.v2.jax import JaxTrainer
    
    def train_loop_per_worker(config):
        import maxtext
        from maxtext.trainers.pre_train.train import main as maxtext_main
    
        argv = config["argv"]
        maxtext_main(argv)
    
    def main(argv: Sequence[str]):
        # Convert the config file path to an absolute path
        argv = list(argv)
        if len(argv) > 1:
            argv[1] = os.path.abspath(argv[1])
    
        trainer = JaxTrainer(
            train_loop_per_worker=train_loop_per_worker,
            train_loop_config={"argv": argv},
            scaling_config=ScalingConfig(
                use_tpu=True,
                # This tells Ray to scale the number of workers between 4 and 8 (i.e. 1 to 2 TPU slices).
                num_workers=(4,8),
                topology="4x4",
                accelerator_type="TPU-V6E",
                resources_per_worker={"TPU": 4},
                placement_strategy="SPREAD",
            ),
            run_config=RunConfig(
                name="maxtext_jaxtrainer",
                # Define a FailureConfig to enable fault tolerance by automatically restarting failed workers.
                failure_config=FailureConfig(max_failures=3),
                worker_runtime_env={
                    "uv": {
                        # maxtext requires some additional deps
                        "packages": ["maxtext[tpu]==0.2.1"],
                        "uv_pip_install_options": ["--resolution=lowest"]
                    },
                },
            ),
        )
        result = trainer.fit()
        logging.info("Training complete!")
        ray.shutdown()
    
    if __name__ == "__main__":
        app.run(main)
    
  2. Ray Job CLI を使用してジョブを送信します。チェックポイントが以前の実行と競合しないように、一意の run_name を指定してください。

    ray job submit \
      --address http://localhost:8265 \
      --working-dir . \
      --runtime-env-json '{"excludes": ["ray-env", ".git"]}' \
      -- python maxtext_elastic_trainer.py \
          base.yml \
          base_output_directory=/data/ \
          dataset_type=synthetic \
          per_device_batch_size=4 \
          max_target_length=4096 \
          model_name=llama3-8b \
          steps=100 \
          ici_fsdp_parallelism=4 \
          ici_tensor_parallelism=4 \
          run_name=rayjob-elastic-8b
    
  3. トレーニング中にノードの終了またはプリエンプションをシミュレートするには、Pod を削除します。

    kubectl delete pod $(kubectl get pods -l ray.io/node-type=worker -o jsonpath='{.items[0].metadata.name}')
    

ターミナルにはワーカーの障害が記録されますが、オーケストレーション コントローラはジョブを存続させ、最小限のトポロジが使用可能になった後、/data/rayjob-elastic-8b/checkpoints チェックポイントから自動的に再開します。

MaxText は再開時にデバイス メッシュを動的に再計算するため、トポロジが縮小したときにチェックポイントの再シャーディングを処理するカスタム ロジックを記述する必要はありません。JAX の Orbax チェックポイントは、トレーニング ループを続行する前に、保存された重みを新しい物理レイアウトに自動的に再シャーディングします。次の出力は、Ray Train コントローラがクラスタで新しく使用可能な TPU リソースを検出し、トレーニング中に 1 つのスライス(4 つのワーカー)から 2 つのスライス(8 つのワーカー)にスケーリング オペレーションを実行することを示しています。

...
(pid=, ip=10.68.9.5) W0421 04:19:07.570048   20579 grpc_transport.cc:1930] GetMultiSliceTopology returned with status: UNAVAILABLE: failed to connect to all addresses; last error: UNKNOWN: ipv4:10.68.8.5:9915: connect endpoint failed (Failed to connect to remote host: Connection refused)
...
(TrainController pid=23150) Detected changes in the cluster resources. Deciding to resize the worker group from 4 -> 8 workers.
(TrainController pid=23150) Using SlicePlacementGroup utility to reserve 2 slice(s) with topology '4x4'...
(TrainController pid=23150) Attempting to start training worker group of size 8 with the following resources: [{'TPU': 4, 'accelerator_type:TPU-V6E': 0.001}] * 8

クリーンアップ

このチュートリアルで使用したリソースについて、 Google Cloud アカウントに課金されないようにするには、リソースを含むプロジェクトを削除するか、プロジェクトを維持して個々のリソースを削除します。

  1. RayCluster を削除します。

    kubectl delete raycluster maxtext-tpu-cluster
    
  2. GKE クラスタを削除します。

    gcloud container clusters delete $CLUSTER_NAME --zone=$ZONE
    
  3. Cloud Storage バケットを削除します。

    gsutil rm -r gs://${GS_BUCKET}
    

次のステップ