このチュートリアルでは、MaxText、Ray Train、Multislice Trillium TPU を使用して、Google Kubernetes Engine(GKE)で Llama 3 70B などの大規模言語モデル(LLM)をトレーニングする方法について説明します。このチュートリアルでは、必要なセカンダリ データセンター ネットワーキングの構成から、32 個の物理 TPU チップに分散されたトレーニング ワークロードを送信して正常に実行するまで、エンドツーエンドの完全なチュートリアルを提供します。
このチュートリアルは、分散マルチホスト TPU スライスで 700 億個のパラメータ モデルをトレーニングする際のメモリとネットワークの課題を克服する方法を学習するプラットフォーム管理者、オペレーター、AI スペシャリストを対象としています。
背景
GKE、KubeRay、MaxText、TPU を組み合わせることで、大規模なモデル トレーニングのための強力でスケーラブルなプラットフォームが実現します。このセクションでは、このガイドで使用されている重要なテクノロジーについて説明します。
JAX
JAX は、アクセラレータ指向の配列計算とプログラム変換のための Python ライブラリです。XLA コンパイラを使用して、アクセラレータで効率的にスケーリングする高度に最適化されたコードを作成します。
MaxText
MaxText は、スケーラビリティとカスタマイズ性を重視して設計された、高パフォーマンスのオープンソース LLM フレームワークです。MaxText は JAX 上に構築されており、Cloud TPU で効率的に実行できるように最適化されています。
TPU
Tensor Processing Unit(TPU)は、機械学習ワークロードを最適化するために Google が作成したカスタム設計のアクセラレータです。汎用 CPU や並列処理 GPU とは異なり、TPU はディープ ラーニングの基盤となる大規模な行列とテンソルの計算に特化しているため、この特定のタスクを効率的に実行できます。TPU の主な利点は、パフォーマンス拡張です。
このチュートリアルでは、マルチスライス デプロイ パターンで第 6 世代 TPU である TPU Trillium を使用します。Cloud TPU マルチスライスは、2 つ以上の Cloud TPU スライスがデータセンター ネットワーク(DCN)上で通信する場所です。マルチスライスは、フルスタックで費用対効果に優れた大規模なトレーニングを可能にします。最大数万の TPU チップまでほぼ線形にスケールアップできます。マルチスライスの詳細については、Cloud TPU マルチスライスの概要をご覧ください。
KubeRay
KubeRay は、Kubernetes で Ray アプリケーションをデプロイ、管理、モニタリングするための統一された方法を提供する Kubernetes オペレーターです。KubeRay オペレーターは、Ray on GKE アドオンを介してインストールおよび管理されます。これは、GKE 上の Ray クラスタをデプロイして管理するおすすめの方法です。
GKE 動的リソース割り当てネットワーク(DRANET)
GKE DRANET(動的リソース割り当てネットワーク)は、高性能ネットワーク デバイスを Pod に動的に接続し、標準の Kubernetes ネットワーキングをバイパスして、DCN で高性能を実現する機能です。
目標
このチュートリアルでは、次の方法を説明します。
- 2 つのマルチホスト TPU ノードプールを使用して GKE クラスタを設定します。
- クロススライス TPU 通信用のセカンダリ DCN を構成します。
- 分散トレーニング環境を管理するように KubeRay を構成します。
- ネットワーク アタッチメントに動的リソース割り当て(DRA)を使用して、RayCluster カスタム リソースをデプロイします。
- Ray Train の JaxTrainer を利用して、TPU スライス全体で MaxText トレーニング ループをオーケストレートする Python トレーニング スクリプトを作成します。
- ベースラインの Llama 3 8B トレーニング ジョブを実行します。
- DCN 上で 2D シャーディング(テンソル並列処理と FSDP)を利用して、Llama 3 70B までスケールアップします。
始める前に
- Google Cloud アカウントにログインします。 Google Cloudを初めて使用する場合は、 アカウントを作成して、実際のシナリオでの Google プロダクトのパフォーマンスを評価してください。新規のお客様には、ワークロードの実行、テスト、デプロイができる無料クレジット $300 分を差し上げます。
-
Google Cloud CLI をインストールします。
-
外部 ID プロバイダ(IdP)を使用している場合は、まず連携 ID を使用して gcloud CLI にログインする必要があります。
-
gcloud CLI を初期化するには、次のコマンドを実行します。
gcloud init -
Google Cloud プロジェクトを作成または選択します。
プロジェクトの選択または作成に必要なロール
- プロジェクトを選択する: プロジェクトの選択に特定の IAM ロールは必要ありません。ロールが付与されているプロジェクトであれば、どのプロジェクトでも選択できます。
-
プロジェクトを作成する: プロジェクトを作成するには、
resourcemanager.projects.create権限を含むプロジェクト作成者ロール(roles/resourcemanager.projectCreator)が必要です。詳しくは、ロールを付与する方法をご覧ください。
-
Google Cloud プロジェクトを作成します。
gcloud projects create PROJECT_ID
PROJECT_IDは、作成する Google Cloud プロジェクトの名前に置き換えます。 -
作成した Google Cloud プロジェクトを選択します。
gcloud config set project PROJECT_ID
PROJECT_IDは、 Google Cloud プロジェクトの名前に置き換えます。
必要な API を有効にします。
API を有効にするために必要なロール
API を有効にするには、
serviceusage.services.enable権限を含む Service Usage 管理者 IAM ロール(roles/serviceusage.serviceUsageAdmin)が必要です。詳しくは、ロールを付与する方法をご覧ください。gcloud services enable container.googleapis.com
cloudbuild.googleapis.com -
Google Cloud CLI をインストールします。
-
外部 ID プロバイダ(IdP)を使用している場合は、まず連携 ID を使用して gcloud CLI にログインする必要があります。
-
gcloud CLI を初期化するには、次のコマンドを実行します。
gcloud init -
Google Cloud プロジェクトを作成または選択します。
プロジェクトの選択または作成に必要なロール
- プロジェクトを選択する: プロジェクトの選択に特定の IAM ロールは必要ありません。ロールが付与されているプロジェクトであれば、どのプロジェクトでも選択できます。
-
プロジェクトを作成する: プロジェクトを作成するには、
resourcemanager.projects.create権限を含むプロジェクト作成者ロール(roles/resourcemanager.projectCreator)が必要です。詳しくは、ロールを付与する方法をご覧ください。
-
Google Cloud プロジェクトを作成します。
gcloud projects create PROJECT_ID
PROJECT_IDは、作成する Google Cloud プロジェクトの名前に置き換えます。 -
作成した Google Cloud プロジェクトを選択します。
gcloud config set project PROJECT_ID
PROJECT_IDは、 Google Cloud プロジェクトの名前に置き換えます。
必要な API を有効にします。
API を有効にするために必要なロール
API を有効にするには、
serviceusage.services.enable権限を含む Service Usage 管理者 IAM ロール(roles/serviceusage.serviceUsageAdmin)が必要です。詳しくは、ロールを付与する方法をご覧ください。gcloud services enable container.googleapis.com
cloudbuild.googleapis.com -
ユーザー アカウントにロールを付与します。次の IAM ロールごとに次のコマンドを 1 回実行します。
roles/container.admin, roles/iam.serviceAccountAdmin, roles/cloudbuild.builds.editorgcloud projects add-iam-policy-binding PROJECT_ID --member="user:USER_IDENTIFIER" --role=ROLE
次のように置き換えます。
PROJECT_ID: プロジェクト ID。USER_IDENTIFIER: ユーザー アカウントの識別子。例:myemail@example.comROLE: ユーザー アカウントに付与する IAM ロール。
- このチュートリアルでは TPU Trillium(v6e)を利用するため、利用可能なリージョンまたはゾーンを選択します。詳細については、Cloud TPU の割り当てをご覧ください。
環境を準備する
このチュートリアルでは、Cloud Shell を使用します。Cloud Shell には、このチュートリアルで使用する gcloud、helm、kubectl コマンドライン ツールがプリインストールされています。
Google Cloud コンソールに移動します。
Google Cloud コンソール ウィンドウの上部にある [Cloud Shell をアクティブにする]
ボタンをクリックします。Google Cloud コンソールの新しいフレーム内で Cloud Shell セッションが開き、コマンドライン プロンプトが表示されます。
ターミナルで、
kubernetes-engine-samplesリポジトリのクローンを作成します。git clone https://github.com/GoogleCloudPlatform/kubernetes-engine-samples.gitサンプル ファイルが含まれているディレクトリに移動します。
cd kubernetes-engine-samples/ai-ml/gke-ray/raytrain/maxtextPython 仮想環境を作成してアクティブにします。
python3 -m venv ray-env source ray-env/bin/activateRay CLI をインストールします。
pip install "ray[default]==2.55.0"次の環境変数を設定します。
export PROJECT_ID=$(gcloud config get project) export PROJECT_NUMBER=$(gcloud projects describe ${PROJECT_ID} --format="value(projectNumber)") export GS_BUCKET=GS_BUCKET export KSA_NAME=KSA_NAME export NAMESPACE=default export CLUSTER_NAME=CLUSTER_NAME export REGION=REGION export ZONE=ZONE export CLUSTER_VERSION=1.35.2-gke.1842000次のように置き換えます。
GS_BUCKET: Cloud Storage バケットの名前。KSA_NAME: Kubernetes サービス アカウントの名前。CLUSTER_NAME: 新しいクラスタの名前。REGION: TPU Trillium の容量が使用可能なリージョン。ZONE: TPU Trillium の容量が使用可能なゾーン。詳細については、GKE での TPU の可用性をご覧ください。
Cloud TPU マルチスライスのクラスタ ネットワーキングを構成する
マルチホスト TPU スライス内では、TPU デバイスは高速チップ間相互接続を介して通信します。ただし、マルチスライス ジョブを実行する場合は、TPU スライスが DCN を介して相互に通信する必要があります。標準の Kubernetes Pod ネットワークでは、このトラフィックがボトルネックになる可能性があります。ct6e-standard-4t マシンタイプは、複数の物理ネットワーク インターフェース カード(NIC)を基盤としています。最高のパフォーマンスを実現するには、2 つの追加の VPC ネットワークを作成し、GKE DRANET を使用して Ray Pod に直接接続します。
大きな最大トレーニング単位(MTU)を使用して、2 つの追加の VPC ネットワークを作成します。
gcloud compute networks create ${CLUSTER_NAME}-net-1 \ --subnet-mode=custom \ --mtu=8896 gcloud compute networks create ${CLUSTER_NAME}-net-2 \ --subnet-mode=custom \ --mtu=8896専用サブネットを作成します。
gcloud compute networks subnets create tpu-subnet-1 \ --network=${CLUSTER_NAME}-net-1 \ --region=${REGION} \ --range=10.50.0.0/16 gcloud compute networks subnets create tpu-subnet-2 \ --network=${CLUSTER_NAME}-net-2 \ --region=${REGION} \ --range=10.60.0.0/16
GKE クラスタを作成する
GKE Autopilot クラスタまたは GKE Standard クラスタの TPU で KubeRay を構成できます。フルマネージドの Kubernetes エクスペリエンスを実現するには、Autopilot クラスタを使用することをおすすめします。ワークロードに最適な GKE の運用モードを選択するには、GKE の運用モードについてをご覧ください。
GKE マネージド DRANET を使用するには、クラスタで Autopilot モードの場合はバージョン 1.35.2-gke.1842000 以降、標準モードの場合は 1.34.1-gke.1829001 以降を使用する必要があります。このチュートリアルでは、バージョン 1.35.2-gke.1842000 を使用します。
Autopilot
Cloud Shell で、次のコマンドを実行します。
gcloud container clusters create-auto $CLUSTER_NAME \ --enable-ray-operator \ --machine-type=n1-standard-16 \ --location=$REGION \ --cluster-version=${CLUSTER_VERSION}クラスタと通信するには、
kubectlを構成します。gcloud container clusters get-credentials CLUSTER_NAME \ --location=$REGION
Standard
Cloud Shell で、次のコマンドを実行して、Ray オペレータ アドオンを有効にする Standard クラスタを作成します。
gcloud container clusters create $CLUSTER_NAME \ --addons=RayOperator,GcsFuseCsiDriver \ --machine-type=n1-standard-16 \ --enable-dataplane-v2 \ --workload-pool=$PROJECT_ID.svc.id.goog \ --location=$ZONE \ --cluster-version=${CLUSTER_VERSION}このコマンドは
GcsFuseCsiDriverも有効にします。これにより、Pod は Cloud Storage バケットをローカル ファイル システムとしてマウントできます。クラスタの作成には数分かかることもあります。クラスタと通信するには、
kubectlを構成します。gcloud container clusters get-credentials CLUSTER_NAME \ --location=$ZONEGKE DRANET を有効にして、最初のマルチホスト TPU スライス ノードプールを作成します。
gcloud container node-pools create v6e-16-0 \ --location=$ZONE \ --cluster=$CLUSTER_NAME \ --machine-type=ct6e-standard-4t \ --threads-per-core=1 \ --tpu-topology=4x4 \ --num-nodes=4 \ --additional-node-network=network=${CLUSTER_NAME}-net-1,subnetwork=tpu-subnet-1 \ --additional-node-network=network=${CLUSTER_NAME}-net-2,subnetwork=tpu-subnet-2 \ --node-labels=cloud.google.com/gke-networking-dra-driver=true \ --enable-gvnic \ --scopes=https://www.googleapis.com/auth/cloud-platform2 つ目の TPU スライス ノードプールを作成します。
gcloud container node-pools create v6e-16-1 \ --location=$ZONE \ --cluster=$CLUSTER_NAME \ --machine-type=ct6e-standard-4t \ --threads-per-core=1 \ --tpu-topology=4x4 \ --num-nodes=4 \ --additional-node-network=network=${CLUSTER_NAME}-net-1,subnetwork=tpu-subnet-1 \ --additional-node-network=network=${CLUSTER_NAME}-net-2,subnetwork=tpu-subnet-2 \ --node-labels=cloud.google.com/gke-networking-dra-driver=true \ --enable-gvnic \ --scopes=https://www.googleapis.com/auth/cloud-platform
GKE は、4 つの TPU Trillium(v6e)VM で構成されるノードプールをプロビジョニングします。これらは、4x4 トポロジを持つマルチホスト TPU スライスとして構成されます。このノードプールは、分散トレーニング ワークロードの準備ができています。
Ray オペレーターが有効になっている GKE クラスタは、クラスタに KubeRay と KubeRay TPU Webhook を自動的にインストールします。
Cloud Storage バケットとサービス アカウントを構成する
マルチホスト TPU ノード間で共有されるチェックポイント用の Cloud Storage バケットを作成します。
gsutil mb -p ${PROJECT_ID} -c STANDARD -l ${REGION} gs://${GS_BUCKET}Cloud Storage バケットへのアクセスを有効にするには、Kubernetes サービス アカウントを作成します。
kubectl create serviceaccount ${KSA_NAME} --namespace ${NAMESPACE}Cloud Storage バケットへのアクセスを有効にするには、必要な IAM ポリシー バインディングをサービス アカウントに追加します。
gcloud storage buckets add-iam-policy-binding gs://${GS_BUCKET} \ --member "principal://iam.googleapis.com/projects/${PROJECT_NUMBER}/locations/global/workloadIdentityPools/${PROJECT_ID}.svc.id.goog/subject/ns/${NAMESPACE}/sa/${KSA_NAME}" \ --role "roles/storage.objectUser"
トレーニング スクリプトを作成する
maxtext_multi_slice_trainer.py スクリプトは、Ray Train の JaxTrainer を使用して、2 つの TPU スライスで分散 MaxText トレーニング ジョブを実行します。このスクリプトは、8 つのマルチホスト TPU ワーカーのトレーニング環境を構成し、各ワーカーノードで MaxText トレーニング ジョブを実行します。train_loop_per_worker 関数は MaxText のメイン エントリ ポイントをラップし、Ray の分散スケジューラを使用してマルチホスト TPU スライスで MaxText トレーナーを実行します。
上記のスクリプトは、8 つのワーカーと 4x4 のトポロジをリクエストする JaxTrainer インスタンスを定義します。内部的には、Ray は 2 つの TPU スライスに SlicePlacementGroup をプロビジョニングし、Ray Train ワーカーが両方のスライスでアトミックに実行されるようにします(ホストごとに 1 つのワーカー)。
モデルのトレーニング
現在のディレクトリの
ray-cluster.tpu-multi-slice.yamlマニフェストは、RayCluster カスタム リソースを定義します。このマニフェストには、GKE DRANET と Multislice のネットワーク デバイスをプロビジョニングする DRANETResourceClaimTemplateが含まれています。上記の RayCluster 仕様では、レプリカごとに 8 つのワーカー(
numOfHosts: 4)を含む TPU ワーカー グループを 2 つのレプリカで作成します。各ワーカーは 4 つの TPU チップ(google.com/tpu: "4")をリクエストします。ワーカーはそれぞれ、同じコロケーションされたマルチホスト スライスの一部である TPU Trillium ノード(tpu-v6e-slice)でスケジュールされます。KubeRay は、スライス内の 4 つのワーカーすべてをアトミックにスケーリングします。必要な JAX 環境変数とスケジューリング用の Pod アフィニティは、変更用 Webhook を介して GKE によってブートストラップされます。RayCluster を作成するには、次のマニフェストを適用します。
envsubst < ray-cluster.tpu-multi-slice.yaml | kubectl apply -f -クラスタが使用できるようになり、実行中であることを確認します。
kubectl get rayclusters maxtext-tpu-cluster出力例を以下に示します。
NAME DESIRED WORKERS AVAILABLE WORKERS CPUS MEMORY GPUS STATUS AGE maxtext-tpu-cluster 8 8 72 1579277216Ki 0 ready 2m11sRay ヘッドサービスを介して Ray ダッシュボードにアクセスするには、ポート転送セッションを確立します。
kubectl port-forward svc/maxtext-tpu-cluster-head-svc 8265:8265 2>&1 >/dev/null &ローカル環境から RayCluster にアクセスできることを確認します。
ray list nodes --address http://localhost:8265出力例を以下に示します。
ray list nodes --address http://localhost:8265 2026-04-21 10:20:20,080 - INFO - Note: NumExpr detected 64 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8. 2026-04-21 10:20:20,080 - INFO - NumExpr defaulting to 8 threads. ======== List: 2026-04-21 10:20:20.945431 ======== Stats: ------------------------------ Total: 9 Table: ------------------------------ NODE_ID NODE_IP IS_HEAD_NODE STATE STATE_MESSAGE NODE_NAME RESOURCES_TOTAL LABELS 0 4f0e4d742de5375047c7688f4d2bc64a42d1e5c77c2d8344b3b375a1 10.68.9.5 False ALIVE 10.68.9.5 CPU: 8.0 ray.io/accelerator-type: TPU-V6E TPU: 4.0 ray.io/node-group: tpu-group accelerator_type:TPU-V6E: 1.0 ray.io/node-id: 4f0e4d742... memory: 186.265 GiB ray.io/tpu-pod-type: v6e-16 node:10.68.9.5: 1.0 ray.io/tpu-slice-name: tpu-group-0 object_store_memory: 186.265 GiB ray.io/tpu-topology: 4x4 tpu-group-0: 1.0 ray.io/tpu-worker-id: '1' ... 6 ce7056807b95831ce107ba1951dac34b80635e6fdbb312e7f9649938 10.68.2.9 True ALIVE 10.68.2.9 CPU: 8.0 ray.io/node-group: headgroup memory: 16.000 GiB ray.io/node-id: ce7056807... node:10.68.2.9: 1.0 node:__internal_head__: 1.0 object_store_memory: 4.765 GiB ...基本の MaxText 構成ファイルをダウンロードします。このファイルは、モデルのデフォルトのハイパーパラメータを設定するためにトレーニング スクリプトで必要です。
curl -O https://raw.githubusercontent.com/google/maxtext/maxtext-v0.2.1/src/maxtext/configs/base.ymlJaxTrainer スクリプトを RayCluster に送信し、RayJob が正常に完了することを確認します。
Llama 3 8B
ray job submit \
--address http://localhost:8265 \
--working-dir . \
--runtime-env-json '{"excludes": ["ray-env", ".git"]}' \
-- python maxtext_multi_slice_trainer.py \
base.yml \
base_output_directory=/data/ \
dataset_type=synthetic \
per_device_batch_size=4 \
max_target_length=4096 \
model_name=llama3-8b \
steps=100 \
ici_fsdp_parallelism=4 \
ici_tensor_parallelism=4 \
run_name=rayjob-multi-slice
Llama 3 70B
ray job submit \
--address http://localhost:8265 \
--working-dir . \
--runtime-env-json '{"excludes": ["ray-env", ".git"]}' \
-- python maxtext_multi_slice_trainer.py \
base.yml \
base_output_directory=/data/ \
dataset_type=synthetic \
per_device_batch_size=2 \
max_target_length=4096 \
model_name=llama3-70b \
steps=100 \
ici_tensor_parallelism=4 \
ici_fsdp_parallelism=4 \
dcn_fsdp_parallelism=2 \
dcn_data_parallelism=1 \
remat_policy=full \
run_name=rayjob-multi-slice-70b-fsdp
上記のコマンドは、JaxTrainer Ray コードを呼び出す Python スクリプトを RayCluster に送信します。ray job submit コマンドには、モデル構成に渡す MaxText に固有の引数が含まれています。
ターミナルに、Llama 3 70B ジョブの出力が表示されます。
[process=5][thread=save_finalize][step=99] CheckpointManager Save Finalize is done on all hosts. [repeated 7x across cluster]
(RayTrainWorker pid=130520, ip=10.60.7.7) [process=5][thread=TrainingThread(train_fn_with_final_checkpoint_flush)][step=99][wait_until_finished] Done waiting for Save Finalize thread (save_finalize) running at step=99. [repeated 7x across cluster]
(RayTrainWorker pid=130520, ip=10.60.7.7) [process=5][thread=TrainingThread(train_fn_with_final_checkpoint_flush)][wait_until_finished] No Save Finalize thread to wait for. Returning. [repeated 6x across cluster]
(RayTrainWorker pid=130520, ip=10.60.7.7) completed step: 99, seconds: 0.693, TFLOP/s/device: 83.171, Tokens/s/device: 11819.175, total_weights: 262144, loss: 0.334 [repeated 6x across cluster]
------------------------------------------
Job 'raysubmit_XwUdZMrhsYRKvjqs' succeeded
------------------------------------------
Spot VM でマルチスライス エラスティック トレーニングを実行する
TPU などの需要の高いアクセラレータを使用する場合は、Spot VM を利用することでコストを大幅に削減できます。ただし、Spot VM は予期せずプリエンプトされることがあります。
Ray Train はエラスティック トレーニングをサポートしています。これにより、ジョブは参加している TPU スライスの数を動的にスケーリングできます。スライスがプリエンプトされると、Ray はトレーニング ループを一時停止し、残りのワーカーが再編成されるのを待ってから、最新の MaxText チェックポイントから復元し、フットプリントを小さくしてトレーニングを再開します。
エラスティック トレーニングを有効にするには、ScalingConfig の num_workers パラメータを静的整数から (minimum_workers, maximum_workers) を表すタプルに変更します。また、RunConfig に FailureConfig(max_failures=3) を追加します。これにより、ワーカーがプリエンプトされたときにジョブ全体を失敗させるのではなく、トレーニング ループを最大 3 回再試行するように Ray Train に指示します。
Ray Train スクリプトを更新する
現在のディレクトリにある
maxtext_elastic_trainer.pyスクリプトにより、エラスティック トレーニングが有効になります。num_workers=(4,8)が設定されていることに注意してください。これは、16 チップ スライス(4 つのワーカー)が 1 つ以上使用可能な場合は Ray に続行するよう指示し、可能であれば 2 つのスライス(8 つのワーカー)にスケールアップするよう指示します。これには、エラスティック トレーニングを有効にし、再試行回数を定義し、ジョブがプリエンプションを回避できるようにするFailureConfigが含まれています。Ray Job CLI を使用してジョブを送信します。チェックポイントが以前の実行と競合しないように、一意の
run_nameを指定してください。ray job submit \ --address http://localhost:8265 \ --working-dir . \ --runtime-env-json '{"excludes": ["ray-env", ".git"]}' \ -- python maxtext_elastic_trainer.py \ base.yml \ base_output_directory=/data/ \ dataset_type=synthetic \ per_device_batch_size=4 \ max_target_length=4096 \ model_name=llama3-8b \ steps=100 \ ici_fsdp_parallelism=4 \ ici_tensor_parallelism=4 \ run_name=rayjob-elastic-8bトレーニング中にノードの終了またはプリエンプションをシミュレートするには、Pod を削除します。
kubectl delete pod $(kubectl get pods -l ray.io/node-type=worker -o jsonpath='{.items[0].metadata.name}')
ターミナルにはワーカーの障害が記録されますが、オーケストレーション コントローラはジョブを存続させ、最小限のトポロジが使用可能になった後、/data/rayjob-elastic-8b/checkpoints チェックポイントから自動的に再開します。
MaxText は再開時にデバイス メッシュを動的に再計算するため、トポロジが縮小したときにチェックポイントの再シャーディングを処理するカスタム ロジックを記述する必要はありません。JAX の Orbax チェックポイントは、トレーニング ループを続行する前に、保存された重みを新しい物理レイアウトに自動的に再シャーディングします。次の出力は、Ray Train コントローラがクラスタで新しく使用可能な TPU リソースを検出し、トレーニング中に 1 つのスライス(4 つのワーカー)から 2 つのスライス(8 つのワーカー)にスケーリング オペレーションを実行することを示しています。
...
(pid=, ip=10.68.9.5) W0421 04:19:07.570048 20579 grpc_transport.cc:1930] GetMultiSliceTopology returned with status: UNAVAILABLE: failed to connect to all addresses; last error: UNKNOWN: ipv4:10.68.8.5:9915: connect endpoint failed (Failed to connect to remote host: Connection refused)
...
(TrainController pid=23150) Detected changes in the cluster resources. Deciding to resize the worker group from 4 -> 8 workers.
(TrainController pid=23150) Using SlicePlacementGroup utility to reserve 2 slice(s) with topology '4x4'...
(TrainController pid=23150) Attempting to start training worker group of size 8 with the following resources: [{'TPU': 4, 'accelerator_type:TPU-V6E': 0.001}] * 8
クリーンアップ
このチュートリアルで使用したリソースについて、 Google Cloud アカウントに課金されないようにするには、リソースを含むプロジェクトを削除するか、プロジェクトを維持して個々のリソースを削除します。
RayCluster を削除します。
kubectl delete raycluster maxtext-tpu-clusterGKE クラスタを削除します。
gcloud container clusters delete $CLUSTER_NAME --zone=$ZONECloud Storage バケットを削除します。
gsutil rm -r gs://${GS_BUCKET}
次のステップ
- Ray on Kubernetes について学習する。
- GKE で TPU を使用して vLLM をサービングする方法を確認する。
- GKE の TPU の詳細を確認する。