TPU 監控程式庫

透過進階 TPU 監控功能,深入瞭解 Cloud TPU 硬體的效能和行為,這些功能直接建構在基礎軟體層 LibTPU 上。LibTPU 包含驅動程式、網路程式庫、XLA 編譯器和 TPU 執行階段,可用於與 TPU 互動,但本文重點是 TPU 監控程式庫。

TPU 監控程式庫提供以下功能:

  • 全面監控:存取遙測 API 和指標套件,深入瞭解 TPU 的運作效能和特定行為。

  • 診斷工具包:提供 SDK 和指令列介面 (CLI),可對 TPU 資源進行偵錯及深入分析效能。

這些監控功能是專為客戶設計的頂層解決方案,可提供必要工具,協助您有效最佳化 TPU 工作負載。

TPU 監控程式庫會提供詳細資訊,說明機器學習工作負載在 TPU 硬體上的執行情況。這項工具可協助您瞭解 TPU 使用率、找出瓶頸,以及排解效能問題。這項指標提供的詳細資訊比中斷指標、有效輸送量指標和其他指標更豐富。

開始使用 TPU 監控程式庫

取得這些實用洞察資料的方式很簡單。TPU 監控功能已整合至 LibTPU SDK,因此安裝 LibTPU 時會一併安裝這項功能。

安裝 LibTPU

pip install libtpu

此外,LibTPU 更新會與 JAX 版本同步,也就是說,安裝最新 JAX 版本 (每月發布) 時,通常會將您固定在最新相容的 LibTPU 版本及其功能。

安裝 JAX

pip install -U "jax[tpu]"

PyTorch 使用者安裝 PyTorch/XLA 後,即可取得最新的 LibTPU 和 TPU 監控功能。

安裝 PyTorch/XLA

pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' \
  -f https://storage.googleapis.com/libtpu-releases/index.html \
  -f https://storage.googleapis.com/libtpu-wheels/index.html

  # Optional: if you're using custom kernels, install pallas dependencies
pip install 'torch_xla[pallas]' \
  -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
  -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

如要進一步瞭解如何安裝 PyTorch/XLA,請參閱 PyTorch/XLA GitHub 存放區中的「安裝」一節。

在 Python 中匯入程式庫

如要開始使用 TPU 監控程式庫,您需要在 Python 程式碼中匯入 libtpu 模組。

from libtpu.sdk import tpumonitoring

列出所有支援的功能

列出所有指標名稱和支援的功能:


from libtpu.sdk import tpumonitoring

tpumonitoring.help()
" libtpu.sdk.monitoring.help():
      List all supported functionality.

  libtpu.sdk.monitoring.list_support_metrics()
      List support metric names in the list of str format.

  libtpu.sdk.monitoring.get_metric(metric_name:str)
      Get metric data with metric name. It represents the snapshot mode.
      The metric data is a object with `description()` and `data()` methods,
      where the `description()` returns a string describe the format of data
      and data unit, `data()` returns the metric data in the list in str format.
"

支援的指標

下列程式碼範例說明如何列出所有支援的指標名稱:

from libtpu.sdk import tpumonitoring

tpumonitoring.list_supported_metrics()

["duty_cycle_pct", "tensorcore_util", "hbm_util", ...]

下表列出所有指標及其對應定義:

指標 定義 API 的指標名稱 範例值
Tensor Core 使用率 測量 TensorCore 的用量百分比,計算方式為 TensorCore 作業的占比。每 1 秒取樣 10 微秒。取樣率無法修改。 您可以透過這項指標,監控 TPU 裝置中的工作負載效率。 tensorcore_util ['1.11', '2.22', '3.33', '4.44']

# utilization percentage for accelerator ID 0-3.
工作週期百分比 過往取樣期間 (每 5 秒;可透過設定 LIBTPU_INIT_ARG 標記調整) 內,加速器積極處理作業的時間占比 (按照上一個取樣期間內執行 HLO 程式的週期記錄)。這項指標代表 TPU 的忙碌程度,並會針對每個晶片發出。 duty_cycle_pct ['10.00', '20.00', '30.00', '40.00']

# Duty cycle percentage for accelerator ID 0-3.
HBM 容量總計 這項指標會以位元組為單位,回報 HBM 總容量。 hbm_capacity_total ['30000000000', '30000000000', '30000000000', '30000000000']

# Total HBM capacity in bytes that attached to accelerator ID 0-3.
HBM 容量用量 這項指標會回報過去取樣期間 (每 5 秒;可透過設定 LIBTPU_INIT_ARG 旗標調整) 的 HBM 容量用量 (以位元組為單位)。 hbm_capacity_usage ['100', '200', '300', '400']

# Capacity usage for HBM in bytes that attached to accelerator ID 0-3.
緩衝區傳輸延遲 巨型多切片流量的網路傳輸延遲。 這項視覺化功能可協助您瞭解整體網路效能環境。 buffer_transfer_latency ["'8MB+', '2233.25', '2182.02', '3761.93', '19277.01', '53553.6'"]

# buffer size, mean, p50, p90, p99, p99.9 of network transfer latency distribution.
高階作業執行時間分配指標 提供 HLO 編譯二進位檔執行狀態的精細效能洞察資料,可偵測迴歸並進行模型層級的偵錯。 hlo_exec_timing ["'tensorcore-0', '10.00', '10.00', '20.00', '30.00', '40.00'"]

# The HLO execution time duration distribution for CoreType-CoreID with mean, p50, p90, p95, p999.
高階最佳化工具佇列大小 HLO 執行佇列大小監控會追蹤等待或正在執行的已編譯 HLO 程式數量。這項指標會顯示執行管道壅塞情形,有助於找出硬體執行、驅動程式負擔或資源分配方面的效能瓶頸。 hlo_queue_size ["tensorcore-0: 1", "tensorcore-1: 2"]

# Measures queue size for CoreType-CoreID.
集體端對端延遲時間 這項指標會測量 DCN 的端對端集體延遲時間 (以微秒為單位),從主機啟動作業到所有對等互連接收輸出內容。包括主機端資料縮減,以及將輸出內容傳送至 TPU。結果是字串,詳細說明緩衝區大小、類型,以及平均、第 50、90、95 和 99.9 個百分位數的延遲時間。 collective_e2e_latency ["8MB+-ALL_REDUCE, 1000, 2000, 3000, 4000, 5000", …]

# Transfer size-collective op, mean, p50, p90, p95, p999 of collective end to end latency.
傳輸層的往返延遲時間 gRPC 用於多切片 TPU 流量的 TCP 連線所觀察到的最短往返時間 (RTT) 分佈。 grpc_tcp_min_round_trip_times ['27.63, 29.03, 38.52, 41.63, 52.74']

# Represents the distribution's mean, p50, p90, p95, and p99.9 percentiles in microseconds (µs).
傳輸層的處理量 gRPC 用於多重切片 TPU 流量的 TCP 連線,其近期總處理量的累積分布。 grpc_tcp_delivery_rates ['11354.89, 10986.35, 18239.55, 25718.35, 44841.55']

# Represents the distribution's mean, p50, p90, p95, and p99.9 percentiles in microseconds (µs).

讀取指標資料

如要讀取指標資料,請在呼叫 tpumonitoring.get_metric 函式時指定指標名稱。您可以在效能不佳的程式碼中插入臨時指標檢查,判斷效能問題是源自軟體還是硬體。

以下程式碼範例說明如何讀取 duty_cycle 指標:

from libtpu.sdk import tpumonitoring

metric = tpumonitoring.get_metric("duty_cycle_pct")

metric.description()
"The metric provides a list of duty cycle percentages, one for each
accelerator (from accelerator_0 to accelerator_x). The duty cycle represents
the percentage of time an accelerator was actively processing during the
last sample period, indicating TPU utilization."

metric.data()
["0.00", "0.00", "0.00", "0.00"]

# accelerator_0-3

使用指標檢查 TPU 使用率

下列範例說明如何使用 TPU 監控程式庫中的指標,追蹤 TPU 使用率。

在 JAX 訓練期間監控 TPU 任務週期

情境:您正在執行 JAX 訓練指令碼,並想在整個訓練過程中監控 TPU 的 duty_cycle_pct 指標,確認 TPU 得到有效運用。您可以在訓練期間定期記錄這項指標,追蹤 TPU 使用率。

下列程式碼範例說明如何在 JAX 訓練期間監控 TPU 負載週期:

import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
import time

 # --- Your JAX model and training setup would go here ---
 #  --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
    return jnp.sum(x)

def loss_fn(params, x, y):
    preds = simple_model(x)
    return jnp.mean((preds - y)**2)

def train_step(params, x, y, optimizer):
    grads = jax.grad(loss_fn)(params, x, y)
    return optimizer.update(grads, params)

key = jax.random.PRNGKey(0)
params = jnp.array([1.0, 2.0]) # Example params
optimizer = ... # Your optimizer (for example, optax.adam)
data_x = jnp.ones((10, 10))
data_y = jnp.zeros((10,))

num_epochs = 10
log_interval_steps = 2  # Log duty cycle every 2 steps

for epoch in range(num_epochs):
    for step in range(5): # Example steps per epoch

        params = train_step(params, data_x, data_y, optimizer)

        if (step + 1) % log_interval_steps == 0:
            # --- Integrate TPU Monitoring Library here to get duty_cycle ---
            duty_cycle_metric = tpumonitoring.get_metric("duty_cycle_pct")
            duty_cycle_data = duty_cycle_metric.data
            print(f"Epoch {epoch+1}, Step {step+1}: TPU Duty Cycle Data:")
            print(f"  Description: {duty_cycle_metric.description}")
            print(f"  Data: {duty_cycle_data}")
            # --- End TPU Monitoring Library Integration ---

        # --- Rest of your training loop logic ---
        time.sleep(0.1) # Simulate some computation

print("Training complete.")

執行 JAX 推論前,請先檢查 HBM 使用率

情境: 使用 JAX 模型執行推論前,請先檢查 TPU 的 HBM (高頻寬記憶體) 使用率,確認有足夠的可用記憶體,並在推論開始前取得基準測量結果。

# The following code sample shows how to check HBM utilization before JAX inference:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring

  # --- Your JAX model and inference setup would go here ---
  # --- Example placeholder model (replace with your actual model loading/setup)---
def simple_model(x):
    return jnp.sum(x)

key = jax.random.PRNGKey(0)
params = ... # Load your trained parameters

  # Integrate the TPU Monitoring Library to get HBM utilization before inference
hbm_util_metric = tpumonitoring.get_metric("hbm_util")
hbm_util_data = hbm_util_metric.data
print("HBM Utilization Before Inference:")
print(f"  Description: {hbm_util_metric.description}")
print(f"  Data: {hbm_util_data}")
  # End TPU Monitoring Library Integration

  # Your Inference Logic
input_data = jnp.ones((1, 10)) # Example input
predictions = simple_model(input_data)
print("Inference Predictions:", predictions)

print("Inference complete.")

查看網路指標

情境: 您正在執行多主機和多配量工作負載,並想使用 SSH 連線至其中一個 GKE Pod 或 TPU,以便在工作負載執行時查看網路指標。這些指令也可以直接併入多主機工作負載。

import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring

 # --- Your JAX model and training setup goes here ---
 #  --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
    return jnp.sum(x)
     # --- details here ---

    # ==============================================================================
    # Metric 1: TCP Delivery Rate
    # ==============================================================================
    # This metric reflects the delivery rate of the TCP connection (bytes delivered / elapsed time).
    # The output is a list of strings representing latency statistics:
    # [mean, p50, p90, p95, p99.9]
    # Example: ['100.00', '200.00', '300.00', '400.00', '500.00'] (Values in Mbps)

        # Get the metric object
        delivery_rate_metric = tpumonitoring.get_metric(metric_name_rate)

        # Print the description provided by the library
        print("Description:", delivery_rate_metric.description())

        # Print the actual data payload
        print("Data:", delivery_rate_metric.data())

    # ==============================================================================
    # Metric 2: TCP Minimum Round Trip Time (RTT)
    # ==============================================================================
    # This metric reflects the minimum RTT measured between sending a TCP packet
    # and receiving the acknowledgement.
    # The output is a list of strings representing latency statistics:
    # [mean, p50, p90, p95, p99.9]
    # Example: ['100.00', '200.00', '300.00', '400.00', '500.00'] (Values in us - microseconds)

        # Get the metric object
        min_rtt_metric = tpumonitoring.get_metric(metric_name_rtt)

        # Print the description provided by the library
        print("Description:", min_rtt_metric.description())

        # Print the actual data payload
        print("Data:", min_rtt_metric.data())

TPU 指標的重新整理頻率

TPU 指標的重新整理頻率下限為一秒。 主機指標資料會以 1 Hz 的固定頻率匯出。這個匯出程序造成的延遲時間微不足道。LibTPU 的執行階段指標不受相同頻率限制。不過,為了保持一致性,這些指標也會以 1 Hz 或每秒 1 個樣本的頻率取樣。

TPU-Z 模組

TPU-Z 是 TPU 的遙測和偵錯設施。這項工具會提供附加至主機的所有 TPU 核心的詳細執行階段狀態資訊。這項功能由 tpuz 模組提供,該模組是 libtpu Python SDK 中 libtpu.sdk 模組的一部分。這個模組會提供每個核心狀態的快照。

TPU-Z 的主要用途是診斷分散式 TPU 工作負載中的停止或死結。您可以在主機上查詢 TPU-Z 服務,擷取每個核心的狀態,比較所有核心的程式計數器、HLO 位置和執行 ID,找出異常狀況。

libtpu.sdk 程式庫中使用 get_core_state_summary() 函式,顯示 TPU-Z 指標:

summary = sdk.tpuz.get_core_state_summary()

TPU-Z 指標的輸出內容會以字典形式提供。以下是單一核心的截斷範例:

{
  "host_name": "my-tpu-host-vm",
  "core_states": {
    "1": {
      "core_id": {
        "global_core_id": 1,
        "chip_id": 0,
        "core_on_chip": {
          "type": "TPU_CORE_TYPE_TENSOR_CORE",
          "index": 1
        }
      },
      "sequencer_info": [
        {
          "sequencer_type": "TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER",
          "sequencer_index": 0,
          "pc": 4490,
          "program_id": 3274167277388825310,
          "run_id": 3
        }
      ],
      "program_fingerprint": "b'\\xbefB\\xc6\\x1eb\\xc1#\\xd0...'",
      "queued_program_info": [],
      "error_message": ""
    }
    // ...
  }
}

如要擷取每個核心的高階最佳化工具 (HLO) 相關資訊,請將 include_hlo_info 參數設為 True

summary = sdk.tpuz.get_core_state_summary(include_hlo_info=True)

輸出內容包含其他 HLO 資訊:

"1": {
    "core_id": {
      "global_core_id": 1,
      "chip_id": 0,
      "core_on_chip": {
        "type": "TPU_CORE_TYPE_TENSOR_CORE",
        "index": 1
      }
    },
    "sequencer_info": [
      {
        "sequencer_type": "TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER",
        "sequencer_index": 0,
        "pc": 17776,
        "tag": 3,
        "tracemark": 2147483646,
        "program_id": 3230481660274331500,
        "run_id": 81,
        "hlo_location": "HLO: fusion.11; HLO computation: main.126_spmd",
        "hlo_detailed_info": "[{\"details\":\"HloModule fusion.11, entry_computation_layout={(bf16[20>..."
      }
    ],
    "program_fingerprint": "\ufffdU\ufffd4j\u7c6e\ufffd\ufffd{\u0017\ufffd\ufffdHHV\ufffdD\ufffde\uff>",
    "launch_id": 1394130914,
    "queued_program_info": [
      {
        "run_id": 81,
        "launch_id": 1394130914,
        "program_fingerprint": "\ufffdU\ufffd4j\u7c6e\ufffd\ufffd{\u0017\ufffd\ufffdHHV\ufffdD\ufffde>"
      }
    ]
  }

TPU-Z 指標

get_core_state_summary 函式會以字典形式傳回 TPU-Z 指標,結構如下。

CurrentCoreStateSummary

CurrentCoreStateSummary 字典提供個別 TPU 核心狀態的詳細摘要。

欄位 類型 說明
core_id dictionary TpuCoreIdentifier 字典,內含 TPU 核心的 ID 資訊。
sequencer_info 字典清單 SequencerInfo 字典清單,說明核心上每個定序器的狀態。
program_fingerprint 位元組 在這個核心上執行的程式指紋。
launch_id 整數 目前或最近一次執行的程式啟動 ID。
queued_program_info 字典清單 已排入執行佇列的程式字典清單。QueuedProgramInfo
error_message 字串 這個核心的任何錯誤訊息。

TpuCoreIdentifier

TpuCoreIdentifier 字典提供 TPU 系統中核心的 ID 資訊。

欄位 類型 說明
global_core_id 整數 核心的 ID。
chip_id 整數 核心所屬晶片的 ID。
core_on_chip dictionary TpuCoreOnChip 字典,說明核心的類型和在晶片上的索引。

TpuCoreOnChip

TpuCoreOnChip 字典包含特定晶片中核心屬性的相關資訊。

欄位 類型 說明
type 字串 TPU 核心類型。例如 TPU_CORE_TYPE_TENSOR_CORE
index 整數 晶片上核心的索引。

SequencerInfo

SequencerInfo 字典包含核心上單一定序器的狀態資訊。

欄位 類型 說明
sequencer_type 字串 定序器類型。例如 TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER
sequencer_index 整數 定序器索引 (如有相同類型的多個定序器)。
pc 整數 目前的程式計數器值。
program_id 整數 與特定程式執行個體相關聯的 ID,該程式執行個體會啟動,以便在 TPU 核心上執行。
run_id 整數 與在 TPU 核心上執行的特定程式執行個體相關聯的執行 ID。
hlo_location 字串 最佳化工具位置資訊概覽。
hlo_detailed_info 字串 高階最佳化工具的詳細資訊。

QueuedProgramInfo

QueuedProgramInfo 字典包含已排入佇列,準備在核心上執行的程式相關資訊。

欄位 類型 說明
run_id 整數 已加入佇列的程式執行 ID。
launch_id 整數 已加入佇列節目的啟動 ID。
program_fingerprint 位元組 已加入佇列節目的指紋。

搭配 JAX 使用 TPU-Z

您可以使用 libtpu.sdk 程式庫,在 JAX 工作負載中存取 TPU-Z 指標。下列 Python 指令碼使用 JAX 進行高效能張量運算,同時在背景執行緒中使用 libtpu SDK,監控基礎 TPU 硬體的狀態和活動。

包括下列 Python 套件:

import jax
import jax.numpy as jnp
import time
import threading
from functools import partial
from libtpu import sdk

monitor_tpu_status 函式會使用背景執行緒,在主要應用程式執行 JAX 工作負載時,持續顯示 TPU 核心的運作狀態。這項工具可即時診斷問題。

def monitor_tpu_status():
  """Monitors TPU status in a background thread."""

  while monitoring_active:
    try:
      summary = sdk.tpuz.get_core_state_summary(include_hlo_info=True)
      if summary and 'core_states' in summary:
        print(summary)
      else:
        print('WARNING: Call returned an empty or invalid summary.')
    except RuntimeError as e:
      print(f'FAIL: Error calling API: {e}')
    except Exception as e:
      print(f'FAIL: Unexpected error in monitor thread: {e}')

    for _ in range(MONITORING_INTERVAL_SECONDS * 2):
      if not monitoring_active:
        break
      time.sleep(0.5)
  print('✅ Monitoring thread stopped.')

transformer_block 函式會實作 Transformer 架構的完整層,這是 LLM 的基礎建構區塊。

@partial(jax.jit, static_argnames=['num_heads'])
def transformer_block(params, x, num_heads=32):
  """A simplified but computationally intensive Transformer block."""
  # Multi-head Self-Attention
  qkv = jnp.dot(x, params['qkv_kernel'])
  q, k, v = jnp.array_split(qkv, 3, axis=-1)

  # Reshape for multi-head attention
  q = q.reshape(q.shape[0], q.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
  k = k.reshape(k.shape[0], k.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
  v = v.reshape(v.shape[0], v.shape[1], num_heads, -1).transpose(0, 2, 1, 3)

  # Scaled dot-product attention
  attention_scores = jnp.einsum('nhqd,nhkd->nhqk', q, k) / jnp.sqrt(q.shape[-1])
  attention_weights = jax.nn.softmax(attention_scores, axis=-1)
  attention_output = jnp.einsum('nhqk,nhvd->nhqd', attention_weights, v)
  attention_output = attention_output.transpose(0, 2, 1, 3).reshape(x.shape)

  attention_output = jnp.dot(attention_output, params['o_kernel'])

  # Residual connection and Layer Normalization 1
  h1 = x + attention_output
  h1_norm = h1 - jnp.mean(h1, axis=-1, keepdims=True)
  h1_norm = h1_norm / jnp.sqrt(
      jnp.mean(jnp.square(h1_norm), axis=-1, keepdims=True) + 1e-5
  )

  # Feed-Forward Network
  ffn_hidden = jax.nn.gelu(jnp.dot(h1_norm, params['ffn1_kernel']))
  ffn_output = jnp.dot(ffn_hidden, params['ffn2_kernel'])

  # Residual connection and Layer Normalization 2
  h2 = h1_norm + ffn_output
  h2_norm = h2 - jnp.mean(h2, axis=-1, keepdims=True)
  h2_norm = h2_norm / jnp.sqrt(
      jnp.mean(jnp.square(h2_norm), axis=-1, keepdims=True) + 1e-5
  )

  return h2_norm

main 函式會協調 JAX 運算設定、啟動背景 TPU 監控,並執行主要工作負載迴圈。

def main():
  num_devices = jax.device_count()
  print(f"Running on {num_devices} devices.")

  batch_size = 128 * num_devices
  seq_len = 512
  embed_dim = 1024
  ffn_dim = embed_dim * 4

  key = jax.random.PRNGKey(0)

  params = {
      'qkv_kernel': jax.random.normal(
          key, (embed_dim, embed_dim * 3), dtype=jnp.bfloat16
      ),
      'o_kernel': jax.random.normal(
          key, (embed_dim, embed_dim), dtype=jnp.bfloat16
      ),
      'ffn1_kernel': jax.random.normal(
          key, (embed_dim, ffn_dim), dtype=jnp.bfloat16
      ),
      'ffn2_kernel': jax.random.normal(
          key, (ffn_dim, embed_dim), dtype=jnp.bfloat16
      ),
  }
  input_data = jax.random.normal(
      key, (batch_size, seq_len, embed_dim), dtype=jnp.bfloat16
  )
  input_data = jax.device_put(input_data)
  monitor_thread = threading.Thread(target=monitor_tpu_status)
  monitor_thread.start()
  print("Starting JAX computation loop...")
  start_time = time.time()
  iterations = 0
  while time.time() - start_time < JOB_DURATION_SECONDS:
    result = transformer_block(params, input_data)
    result.block_until_ready()
    iterations += 1
    print(f'  -> Jax iteration {iterations} complete.', end='\r')

  print(f"\nCompleted {iterations} iterations in {JOB_DURATION_SECONDS} seconds.")

  global monitoring_active
  monitoring_active = False
  monitor_thread.join()

if __name__ == '__main__':
  main()

疑難排解

本節提供疑難排解資訊,協助您找出並解決使用 TPU 監控程式庫時可能遇到的問題。

缺少功能或指標

如果無法查看某些功能或指標,最常見的原因是libtpu版本過舊。TPU 監控程式庫功能和指標會納入 libtpu 版本,舊版可能缺少新功能和指標。

檢查環境中執行的 libtpu 版本:

指令列:

pip show libtpu

Python:

import libtpu

print(libtpu.__version__)

如果不是使用 libtpu最新版本,請使用下列指令更新程式庫:

pip install --upgrade libtpu