JAX - TPU のトラブルシューティング

このガイドでは、Cloud TPU で JAX モデルをトレーニングする際に発生する可能性のある問題の特定と解決に役立つ、JAX のトラブルシューティング情報について説明します。

Cloud TPU を使い始める際の一般的なガイドについては、JAX クイックスタートをご覧ください。

JAX に関する一般的な問題

トレーニング モデルの開発中や JAX でのトレーニング中に問題が発生した場合は、JAX のよくある質問をご覧ください。

JAX を使用してトレーニング アプリケーションを作成するときに発生する可能性のある一般的なプログラミング エラーについては、JAX エラーをご覧ください。

JAX パフォーマンスをプロファイリングする

JAX パフォーマンスのプロファイリングで説明されているツールを使用して、TPU リソースの使用状況を把握できます。

メモリの問題のトラブルシューティング

JAX デバイス メモリ プロファイラでメモリの使用状況をモニタリングできますが、その使用状況について直接管理することはできません。

JAX デバイス メモリ プロファイラを使用すると、次のことができます。

TPU メモリを特定のオペレーションに割り当てる方法は指定できません。TPU パフォーマンスに関する JAX 固有の問題の詳細については、JAX で TPU を使用する場合のパフォーマンスに関する注をご覧ください。

TPU の問題のトラブルシューティング

以降のセクションでは、TPU で JAX プログラムを実行する際に発生する可能性のある一般的な問題の解決方法について説明します。

TPU が実行されていることを確認する方法

JAX から「No GPU/TPU found, falling back to CPU.」と出力されない限り、すべてが TPU で実行されます。

TPU がアクティブであることを確認するには、jax.devices() で複数の TPU デバイスが表示されていることを確認するか、assert jax.devices()[0].platform == 'tpu' を使用してプログラムで確認します。

RuntimeError: Unable to initialize backend 'tpu': UNAVAILABLE: No TPU Platform available.

このランタイム エラー メッセージや、TPU VM の /tmp/tpu_logs/tpu_driver.WARNING で次の W1118 17:40:20.985243 23901 tpu_version_flag.cc:57] No hardware is found. Using default TPU version:xxxxxx が見つかった場合は、間違った TPU VM バージョンを実行している可能性があります。

現行の JAX ランタイム バージョンを実行していることを確認し、再試行します。

TPU と GKE の問題のトラブルシューティング

トラブルシューティングに役立つように、GKE ワークロード マニフェストで詳細ログを有効にしてから、ログを GKE サポートに提供します。

TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0

以降のセクションでは、TPU と GKE の設定に関連するエラー メッセージとその解決方法について説明します。

サービス「jobset-webhook-service」に使用できるエンドポイントがない

このエラーは、ジョブセットが正しくインストールされていないことを意味します。jobset-controller-manager Deployment Kubernetes Pod が実行されているかどうかを確認します。詳細については、JobSet のトラブルシューティングに関するドキュメントをご覧ください。

TPU の初期化に失敗しました: 接続できませんでした

GKE ノード バージョンが 1.30.4-gke.1348000 以降であることを確認します(GKE 1.31 はサポートされていません)。