TPU モニタリング ライブラリ

基盤ソフトウェア レイヤの LibTPU に直接構築された高度な TPU モニタリング機能を使用して、Cloud TPU ハードウェアのパフォーマンスや動作に関する詳細な分析情報を取得できます。LibTPU には、TPU とのやり取りに使用されるドライバ、ネットワーキング ライブラリ、XLA コンパイラ、TPU ランタイムが含まれますが、このドキュメントでは TPU モニタリング ライブラリに焦点を当てます。

TPU モニタリング ライブラリは次の機能を提供します。

  • 包括的なオブザーバビリティ: Telemetry API と指標スイートにアクセスできます。これにより、TPU の運用パフォーマンスや特定の動作に関する詳細な分析情報が得られます。

  • 診断ツールキット: TPU リソースのデバッグや詳細なパフォーマンス分析を可能にするために設計された SDK とコマンドライン インターフェース(CLI)を提供します。

これらのモニタリング機能は、トップレベルのお客様向けソリューションとして設計されており、TPU ワークロードを効果的に最適化するために不可欠なツールとなります。

TPU モニタリング ライブラリを使用すると、ML ワークロードが TPU ハードウェア上でどのように動作しているかを示す詳細な情報が得られます。これは、TPU 使用率の把握、ボトルネックの特定、パフォーマンスに関する問題のデバッグに役立つよう設計されています。得られる情報は、中断指標、グッドプット指標、その他の指標よりも詳細です。

TPU モニタリング ライブラリを使ってみる

これらの有益な分析情報にアクセスするのは簡単です。TPU モニタリング機能は LibTPU SDK と統合されているため、LibTPU をインストールするとこの機能もインストールされます。

LibTPU をインストールする

pip install libtpu

また、LibTPU のアップデートは JAX のリリースと連携しています。つまり、最新の JAX リリース(月 1 回リリースされる)をインストールすると、通常は互換性のある最新の LibTPU バージョンとその機能がご使用のシステムに紐づけされます。

JAX をインストールする

pip install -U "jax[tpu]"

PyTorch ユーザーの場合は、PyTorch/XLA をインストールすると、最新の LibTPU と TPU モニタリング機能が提供されます。

PyTorch / XLA をインストールする

pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' \
  -f https://storage.googleapis.com/libtpu-releases/index.html \
  -f https://storage.googleapis.com/libtpu-wheels/index.html

  # Optional: if you're using custom kernels, install pallas dependencies
pip install 'torch_xla[pallas]' \
  -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
  -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

PyTorch/XLA のインストールの詳細については、PyTorch/XLA GitHub リポジトリのインストールをご覧ください。

Python でライブラリをインポートする

TPU Monitoring ライブラリの使用を開始するには、Python コードに libtpu モジュールをインポートする必要があります。

from libtpu.sdk import tpumonitoring

サポートされているすべての機能を一覧表示する

すべての指標名と、それらの指標でサポートされている機能を一覧表示します。


from libtpu.sdk import tpumonitoring

tpumonitoring.help()
" libtpu.sdk.monitoring.help():
      List all supported functionality.

  libtpu.sdk.monitoring.list_support_metrics()
      List support metric names in the list of str format.

  libtpu.sdk.monitoring.get_metric(metric_name:str)
      Get metric data with metric name. It represents the snapshot mode.
      The metric data is a object with `description()` and `data()` methods,
      where the `description()` returns a string describe the format of data
      and data unit, `data()` returns the metric data in the list in str format.
"

サポートされている指標

次のコードサンプルは、サポートされているすべての指標名を一覧表示する方法を示しています。

from libtpu.sdk import tpumonitoring

tpumonitoring.list_supported_metrics()

["duty_cycle_pct", "tensorcore_util", "hbm_util", ...]

次の表に、すべての指標とその定義を示します。

指標 定義 API の指標名 値の例
Tensor Core 使用率 TensorCore の使用率を測定します。これは、TensorCore 演算の一部である演算の割合として計算されます。1 秒ごとに 10 マイクロ秒のデータがサンプリングされます。サンプリング レートを変更することはできません。この指標により、TPU デバイスでのワークロードの効率をモニタリングできます。 tensorcore_util ['1.11', '2.22', '3.33', '4.44']

# アクセラレータ ID 0~3 の使用率(%)
デューティ サイクルの割合 過去のサンプル期間(5 秒ごと。LIBTPU_INIT_ARG フラグによって調整可能)中にアクセラレータがアクティブに処理していた時間の割合(最後のサンプリング期間中に HLO プログラムの実行に使用されていたサイクルとともに記録される)。この指標は、TPU がどれだけビジーであるかを表します。この指標はチップごとに出力されます。 duty_cycle_pct ['10.00', '20.00', '30.00', '40.00']

# アクセラレータ ID 0~3 のデューティ サイクルの割合(%)
HBM 容量の合計 この指標は、HBM の合計容量をバイト単位で報告します。 hbm_capacity_total ['30000000000', '30000000000', '30000000000', '30000000000']

# アクセラレータ ID 0~3 にアタッチされている HBM 容量の合計(バイト単位)。
HBM 容量の使用量 この指標は、過去のサンプル期間(5 秒ごと。LIBTPU_INIT_ARG フラグによって調整可能)中の HBM 容量の使用量をバイト単位で報告します。 hbm_capacity_usage ['100', '200', '300', '400']

# アクセラレータ ID 0~3 にアタッチされている HBM 容量の使用量(バイト単位)。
バッファ転送レイテンシ 大規模なマルチスライス トラフィックのネットワーク転送レイテンシ。この可視化により、全体的なネットワーク パフォーマンス環境を把握できます。 buffer_transfer_latency ["'8MB+', '2233.25', '2182.02', '3761.93', '19277.01', '53553.6'"]

# ネットワーク転送レイテンシ分布のバッファサイズ、平均、p50、p90、p99、p99.9。
高レベル演算の実行時間分布指標 HLO コンパイル済みバイナリの実行ステータスに関する詳細なパフォーマンス分析情報を提供します。これにより、回帰検出とモデルレベルのデバッグが可能になります。 hlo_exec_timing ["'tensorcore-0', '10.00', '10.00', '20.00', '30.00', '40.00'"]

# CoreType-CoreID の HLO 実行時間の分布(平均、p50、p90、p95、p999)。
High Level Optimizer のキューサイズ HLO 実行キューサイズのモニタリングは、待機中または実行中のコンパイル済み HLO プログラムの数を追跡します。この指標は、実行パイプラインの輻輳を明らかにし、ハードウェア実行、ドライバ オーバーヘッド、リソース割り当てにおけるパフォーマンス ボトルネックの特定に役立ちます。 hlo_queue_size ["tensorcore-0: 1", "tensorcore-1: 2"]

# CoreType-CoreID のキューサイズを測定します。
エンドツーエンドのグループ レイテンシ この指標は、オペレーションを開始したホストから、出力を受信したすべてのピアへの DCN を介したエンドツーエンドのグループ レイテンシをマイクロ秒単位で測定します。これには、ホスト側のデータ削減と TPU への出力の送信が含まれます。結果は、バッファサイズ、タイプ、平均、p50、p90、p95、p99.9 のレイテンシを詳細に説明する文字列です。 collective_e2e_latency ["8MB+-ALL_REDUCE, 1000, 2000, 3000, 4000, 5000", …]

# 転送サイズ - エンドツーエンドのグループ レイテンシのグループ オペレーション、平均、p50、p90、p95、p999。
トランスポート レイヤでのラウンドトリップ レイテンシ gRPC がマルチスライス TPU トラフィックに使用する TCP 接続で観測された最小ラウンドトリップ時間(RTT)の分布。 grpc_tcp_min_round_trip_times ['27.63, 29.03, 38.52, 41.63, 52.74']

# 分布の平均、p50、p90、p95、p99.9 パーセンタイルをマイクロ秒(µs)で表します。
トランスポート レイヤのスループット マルチスライス TPU トラフィックの gRPC で使用される TCP 接続の最近のスループットの累積分布。 grpc_tcp_delivery_rates ['11354.89, 10986.35, 18239.55, 25718.35, 44841.55']

# 分布の平均、p50、p90、p95、p99.9 パーセンタイルをマイクロ秒(µs)で表します。

指標データを読み取る

指標データを読み取るには、tpumonitoring.get_metric 関数を呼び出すときに指標名を指定します。パフォーマンスの悪いコードにアドホック指標チェックを挿入することで、パフォーマンスの問題がソフトウェアとハードウェアのどちらに起因するのかを特定できます。

次のコードサンプルは、duty_cycle 指標を読み取る方法を示しています。

from libtpu.sdk import tpumonitoring

metric = tpumonitoring.get_metric("duty_cycle_pct")

metric.description()
"The metric provides a list of duty cycle percentages, one for each
accelerator (from accelerator_0 to accelerator_x). The duty cycle represents
the percentage of time an accelerator was actively processing during the
last sample period, indicating TPU utilization."

metric.data()
["0.00", "0.00", "0.00", "0.00"]

# accelerator_0-3

指標を使用して TPU 使用率を確認する

次の例は、TPU Monitoring ライブラリの指標を使用して TPU 使用率を追跡する方法を示しています。

JAX トレーニング中の TPU デューティ サイクルをモニタリングする

シナリオ: JAX トレーニング スクリプトを実行し、トレーニング プロセスの間中 TPU の duty_cycle_pct 指標をモニタリングして TPU が効果的に使用されていることを確認します。トレーニング中にこの指標を定期的にログに出力することで、TPU 使用率を追跡できます。

次のコードサンプルは、JAX トレーニング中に TPU のデューティ サイクルをモニタリングする方法を示しています。

import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
import time

 # --- Your JAX model and training setup would go here ---
 #  --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
    return jnp.sum(x)

def loss_fn(params, x, y):
    preds = simple_model(x)
    return jnp.mean((preds - y)**2)

def train_step(params, x, y, optimizer):
    grads = jax.grad(loss_fn)(params, x, y)
    return optimizer.update(grads, params)

key = jax.random.PRNGKey(0)
params = jnp.array([1.0, 2.0]) # Example params
optimizer = ... # Your optimizer (for example, optax.adam)
data_x = jnp.ones((10, 10))
data_y = jnp.zeros((10,))

num_epochs = 10
log_interval_steps = 2  # Log duty cycle every 2 steps

for epoch in range(num_epochs):
    for step in range(5): # Example steps per epoch

        params = train_step(params, data_x, data_y, optimizer)

        if (step + 1) % log_interval_steps == 0:
            # --- Integrate TPU Monitoring Library here to get duty_cycle ---
            duty_cycle_metric = tpumonitoring.get_metric("duty_cycle_pct")
            duty_cycle_data = duty_cycle_metric.data
            print(f"Epoch {epoch+1}, Step {step+1}: TPU Duty Cycle Data:")
            print(f"  Description: {duty_cycle_metric.description}")
            print(f"  Data: {duty_cycle_data}")
            # --- End TPU Monitoring Library Integration ---

        # --- Rest of your training loop logic ---
        time.sleep(0.1) # Simulate some computation

print("Training complete.")

JAX 推論を実行する前に HBM 使用率を確認する

シナリオ: JAX モデルで推論を実行する前に、TPU の現在の HBM(高帯域幅メモリ)使用率をチェックして十分な量のメモリが使用可能であることを確認します。また、推論の開始前にベースライン測定値を取得します。

# The following code sample shows how to check HBM utilization before JAX inference:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring

  # --- Your JAX model and inference setup would go here ---
  # --- Example placeholder model (replace with your actual model loading/setup)---
def simple_model(x):
    return jnp.sum(x)

key = jax.random.PRNGKey(0)
params = ... # Load your trained parameters

  # Integrate the TPU Monitoring Library to get HBM utilization before inference
hbm_util_metric = tpumonitoring.get_metric("hbm_util")
hbm_util_data = hbm_util_metric.data
print("HBM Utilization Before Inference:")
print(f"  Description: {hbm_util_metric.description}")
print(f"  Data: {hbm_util_data}")
  # End TPU Monitoring Library Integration

  # Your Inference Logic
input_data = jnp.ones((1, 10)) # Example input
predictions = simple_model(input_data)
print("Inference Predictions:", predictions)

print("Inference complete.")

ネットワーク指標を確認する

シナリオ: マルチホストとマルチスライス ワークロードを実行しており、ワークロードの実行中にネットワーク指標を表示するために、ssh を使用して GKE Pod または TPU のいずれかに接続したい。コマンドは、マルチホスト ワークロードに直接組み込むこともできます。

import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring

 # --- Your JAX model and training setup goes here ---
 #  --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
    return jnp.sum(x)
     # --- details here ---

    # ==============================================================================
    # Metric 1: TCP Delivery Rate
    # ==============================================================================
    # This metric reflects the delivery rate of the TCP connection (bytes delivered / elapsed time).
    # The output is a list of strings representing latency statistics:
    # [mean, p50, p90, p95, p99.9]
    # Example: ['100.00', '200.00', '300.00', '400.00', '500.00'] (Values in Mbps)

        # Get the metric object
        delivery_rate_metric = tpumonitoring.get_metric(metric_name_rate)

        # Print the description provided by the library
        print("Description:", delivery_rate_metric.description())

        # Print the actual data payload
        print("Data:", delivery_rate_metric.data())

    # ==============================================================================
    # Metric 2: TCP Minimum Round Trip Time (RTT)
    # ==============================================================================
    # This metric reflects the minimum RTT measured between sending a TCP packet
    # and receiving the acknowledgement.
    # The output is a list of strings representing latency statistics:
    # [mean, p50, p90, p95, p99.9]
    # Example: ['100.00', '200.00', '300.00', '400.00', '500.00'] (Values in us - microseconds)

        # Get the metric object
        min_rtt_metric = tpumonitoring.get_metric(metric_name_rtt)

        # Print the description provided by the library
        print("Description:", min_rtt_metric.description())

        # Print the actual data payload
        print("Data:", min_rtt_metric.data())

TPU 指標の更新頻度

TPU 指標の更新頻度は、最小 1 秒に 1 回に制限されています。ホスト指標データは 1 Hz の固定頻度でエクスポートされます。このエクスポート プロセスによって発生するレイテンシはごくわずかです。LibTPU のランタイム指標には、同じ頻度の制約は適用されません。ただし、一貫性を保つため、これらの指標も 1 Hz(1 秒あたり 1 サンプル)でサンプリングされます。

TPU-Z モジュール

TPU-Z は、TPU のテレメトリーとデバッグ機能です。ホストに接続されているすべての TPU コアの詳細なランタイム ステータス情報を提供します。この機能は、libtpu Python SDK の libtpu.sdk モジュールの一部である tpuz モジュールを通じて提供されます。このモジュールは、各コアの状態のスナップショットを提供します。

TPU-Z の主なユースケースは、分散 TPU ワークロードのハングやデッドロックを診断することです。ホストで TPU-Z サービスをクエリして、すべてのコアの状態をキャプチャし、すべてのコアのプログラム カウンタ、HLO の場所、実行 ID を比較して、異常を特定できます。

libtpu.sdk ライブラリ内の get_core_state_summary() 関数を使用して、TPU-Z 指標を表示します。

summary = sdk.tpuz.get_core_state_summary()

TPU-Z 指標の出力はディクショナリとして提供されます。単一コアの短縮された例を次に示します。

{
  "host_name": "my-tpu-host-vm",
  "core_states": {
    "1": {
      "core_id": {
        "global_core_id": 1,
        "chip_id": 0,
        "core_on_chip": {
          "type": "TPU_CORE_TYPE_TENSOR_CORE",
          "index": 1
        }
      },
      "sequencer_info": [
        {
          "sequencer_type": "TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER",
          "sequencer_index": 0,
          "pc": 4490,
          "program_id": 3274167277388825310,
          "run_id": 3
        }
      ],
      "program_fingerprint": "b'\\xbefB\\xc6\\x1eb\\xc1#\\xd0...'",
      "queued_program_info": [],
      "error_message": ""
    }
    // ...
  }
}

各コアの High-Level Optimizer(HLO)に関する情報を取得するには、include_hlo_info パラメータを True に設定します。

summary = sdk.tpuz.get_core_state_summary(include_hlo_info=True)

出力には、追加の HLO 情報が含まれます。

"1": {
    "core_id": {
      "global_core_id": 1,
      "chip_id": 0,
      "core_on_chip": {
        "type": "TPU_CORE_TYPE_TENSOR_CORE",
        "index": 1
      }
    },
    "sequencer_info": [
      {
        "sequencer_type": "TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER",
        "sequencer_index": 0,
        "pc": 17776,
        "tag": 3,
        "tracemark": 2147483646,
        "program_id": 3230481660274331500,
        "run_id": 81,
        "hlo_location": "HLO: fusion.11; HLO computation: main.126_spmd",
        "hlo_detailed_info": "[{\"details\":\"HloModule fusion.11, entry_computation_layout={(bf16[20>..."
      }
    ],
    "program_fingerprint": "\ufffdU\ufffd4j\u7c6e\ufffd\ufffd{\u0017\ufffd\ufffdHHV\ufffdD\ufffde\uff>",
    "launch_id": 1394130914,
    "queued_program_info": [
      {
        "run_id": 81,
        "launch_id": 1394130914,
        "program_fingerprint": "\ufffdU\ufffd4j\u7c6e\ufffd\ufffd{\u0017\ufffd\ufffdHHV\ufffdD\ufffde>"
      }
    ]
  }

TPU-Z 指標

get_core_state_summary 関数は、次の構造のディクショナリ形式で TPU-Z 指標を返します。

CurrentCoreStateSummary

CurrentCoreStateSummary 辞書には、個々の TPU コアの状態の詳細な概要が示されます。

フィールド 説明
core_id 辞書 TPU コアの ID 情報を含む TpuCoreIdentifier 辞書。
sequencer_info 辞書のリスト コア上の各シーケンサーの状態を記述する SequencerInfo 辞書のリスト。
program_fingerprint バイト このコアで実行されているプログラムのフィンガープリント。
launch_id 整数 現在または最新のプログラムの起動 ID。
queued_program_info 辞書のリスト 実行待ちのプログラムの QueuedProgramInfo 辞書のリスト。
error_message 文字列 このコアのエラー メッセージ。

TpuCoreIdentifier

TpuCoreIdentifier 辞書は、TPU システム内のコアの ID 情報を提供します。

フィールド 説明
global_core_id 整数 コアの ID。
chip_id 整数 コアが属するチップの ID。
core_on_chip 辞書 コアのタイプとチップ上のインデックスを記述する TpuCoreOnChip 辞書。

TpuCoreOnChip

TpuCoreOnChip 辞書には、特定のチップ内のコアのプロパティに関する情報が含まれています。

フィールド 説明
type 文字列 TPU コアのタイプ。例: TPU_CORE_TYPE_TENSOR_CORE
index 整数 チップ上のコアのインデックス。

SequencerInfo

SequencerInfo 辞書には、コア上の単一のシーケンサーの状態に関する情報が含まれています。

フィールド 説明
sequencer_type 文字列 シーケンサーのタイプ。 例: TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER
sequencer_index 整数 シーケンサーのインデックス(同じタイプのシーケンサーが複数ある場合)。
pc 整数 現在のプログラム カウンタの値。
program_id 整数 TPU コアで実行するために起動されるプログラムの特定のインスタンスに関連付けられた ID。
run_id 整数 TPU コアでのプログラム実行の特定のインスタンスに関連付けられた実行 ID。
hlo_location 文字列 High Level Optimizer の位置情報。
hlo_detailed_info 文字列 High Level Optimizer の詳細情報。

QueuedProgramInfo

QueuedProgramInfo 辞書には、コアで実行するためにキューに登録されたプログラムに関する情報が含まれます。

フィールド 説明
run_id 整数 キューに登録されたプログラムの実行 ID。
launch_id 整数 キューに登録されたプログラムの起動 ID。
program_fingerprint バイト キューに登録されたプログラムのフィンガープリント。

JAX の TPU-Z

JAX ワークロードで TPU-Z 指標にアクセスするには、libtpu.sdk ライブラリを使用します。次の Python スクリプトでは、高パフォーマンスのテンソル計算に JAX を使用すると同時に、バックグラウンド スレッドで libtpu SDK を使用して、基盤となる TPU ハードウェアの状態とアクティビティをモニタリングします。

次の Python パッケージを含めます。

import jax
import jax.numpy as jnp
import time
import threading
from functools import partial
from libtpu import sdk

monitor_tpu_status 関数は、バックグラウンド スレッドを使用して、メイン アプリケーションが JAX ワークロードを実行している間、TPU コアの動作ステータスを継続的に表示します。これは、リアルタイムの診断ツールとして機能します。

def monitor_tpu_status():
  """Monitors TPU status in a background thread."""

  while monitoring_active:
    try:
      summary = sdk.tpuz.get_core_state_summary(include_hlo_info=True)
      if summary and 'core_states' in summary:
        print(summary)
      else:
        print('WARNING: Call returned an empty or invalid summary.')
    except RuntimeError as e:
      print(f'FAIL: Error calling API: {e}')
    except Exception as e:
      print(f'FAIL: Unexpected error in monitor thread: {e}')

    for _ in range(MONITORING_INTERVAL_SECONDS * 2):
      if not monitoring_active:
        break
      time.sleep(0.5)
  print('✅ Monitoring thread stopped.')

transformer_block 関数は、LLM の基本的なビルディング ブロックである Transformer アーキテクチャの完全なレイヤを実装します。

@partial(jax.jit, static_argnames=['num_heads'])
def transformer_block(params, x, num_heads=32):
  """A simplified but computationally intensive Transformer block."""
  # Multi-head Self-Attention
  qkv = jnp.dot(x, params['qkv_kernel'])
  q, k, v = jnp.array_split(qkv, 3, axis=-1)

  # Reshape for multi-head attention
  q = q.reshape(q.shape[0], q.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
  k = k.reshape(k.shape[0], k.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
  v = v.reshape(v.shape[0], v.shape[1], num_heads, -1).transpose(0, 2, 1, 3)

  # Scaled dot-product attention
  attention_scores = jnp.einsum('nhqd,nhkd->nhqk', q, k) / jnp.sqrt(q.shape[-1])
  attention_weights = jax.nn.softmax(attention_scores, axis=-1)
  attention_output = jnp.einsum('nhqk,nhvd->nhqd', attention_weights, v)
  attention_output = attention_output.transpose(0, 2, 1, 3).reshape(x.shape)

  attention_output = jnp.dot(attention_output, params['o_kernel'])

  # Residual connection and Layer Normalization 1
  h1 = x + attention_output
  h1_norm = h1 - jnp.mean(h1, axis=-1, keepdims=True)
  h1_norm = h1_norm / jnp.sqrt(
      jnp.mean(jnp.square(h1_norm), axis=-1, keepdims=True) + 1e-5
  )

  # Feed-Forward Network
  ffn_hidden = jax.nn.gelu(jnp.dot(h1_norm, params['ffn1_kernel']))
  ffn_output = jnp.dot(ffn_hidden, params['ffn2_kernel'])

  # Residual connection and Layer Normalization 2
  h2 = h1_norm + ffn_output
  h2_norm = h2 - jnp.mean(h2, axis=-1, keepdims=True)
  h2_norm = h2_norm / jnp.sqrt(
      jnp.mean(jnp.square(h2_norm), axis=-1, keepdims=True) + 1e-5
  )

  return h2_norm

main 関数は、JAX 計算の設定をオーケストレートし、バックグラウンド TPU モニタリングを開始して、メイン ワークロード ループを実行します。

def main():
  num_devices = jax.device_count()
  print(f"Running on {num_devices} devices.")

  batch_size = 128 * num_devices
  seq_len = 512
  embed_dim = 1024
  ffn_dim = embed_dim * 4

  key = jax.random.PRNGKey(0)

  params = {
      'qkv_kernel': jax.random.normal(
          key, (embed_dim, embed_dim * 3), dtype=jnp.bfloat16
      ),
      'o_kernel': jax.random.normal(
          key, (embed_dim, embed_dim), dtype=jnp.bfloat16
      ),
      'ffn1_kernel': jax.random.normal(
          key, (embed_dim, ffn_dim), dtype=jnp.bfloat16
      ),
      'ffn2_kernel': jax.random.normal(
          key, (ffn_dim, embed_dim), dtype=jnp.bfloat16
      ),
  }
  input_data = jax.random.normal(
      key, (batch_size, seq_len, embed_dim), dtype=jnp.bfloat16
  )
  input_data = jax.device_put(input_data)
  monitor_thread = threading.Thread(target=monitor_tpu_status)
  monitor_thread.start()
  print("Starting JAX computation loop...")
  start_time = time.time()
  iterations = 0
  while time.time() - start_time < JOB_DURATION_SECONDS:
    result = transformer_block(params, input_data)
    result.block_until_ready()
    iterations += 1
    print(f'  -> Jax iteration {iterations} complete.', end='\r')

  print(f"\nCompleted {iterations} iterations in {JOB_DURATION_SECONDS} seconds.")

  global monitoring_active
  monitoring_active = False
  monitor_thread.join()

if __name__ == '__main__':
  main()

トラブルシューティング

このセクションでは、TPU モニタリング ライブラリの使用中に発生する可能性のある問題の特定と解決に役立つトラブルシューティング情報について説明します。

機能または指標が欠落している

一部の機能や指標が表示されない場合、最も一般的な原因は libtpu のバージョンが古いことです。TPU モニタリング ライブラリの機能と指標は libtpu リリースに含まれています。古いバージョンでは、新しい機能と指標が欠落している可能性があります。

環境で実行されている libtpu のバージョンを確認します。

コマンドライン:

pip show libtpu

Python:

import libtpu

print(libtpu.__version__)

libtpu最新バージョンを使用していない場合は、次のコマンドを使用してライブラリを更新します。

pip install --upgrade libtpu