TPU 監控程式庫
透過進階 TPU 監控功能,深入瞭解 Cloud TPU 硬體的效能和行為,這些功能直接建構在基礎軟體層 LibTPU 上。LibTPU 包含驅動程式、網路程式庫、XLA 編譯器和 TPU 執行階段,可用於與 TPU 互動,但本文重點是 TPU 監控程式庫。
TPU 監控程式庫提供以下功能:
全面監控:存取遙測 API 和指標套件,深入瞭解 TPU 的運作效能和特定行為。
診斷工具包:提供 SDK 和指令列介面 (CLI),可對 TPU 資源進行偵錯及深入分析效能。
這些監控功能是專為客戶設計的頂層解決方案,可提供必要工具,協助您有效最佳化 TPU 工作負載。
TPU 監控程式庫會提供詳細資訊,說明機器學習工作負載在 TPU 硬體上的執行情況。這項工具可協助您瞭解 TPU 使用率、找出瓶頸,以及排解效能問題。這項指標提供的詳細資訊比中斷指標、有效輸送量指標和其他指標更豐富。
開始使用 TPU 監控程式庫
取得這些實用洞察資料的方式很簡單。TPU 監控功能已整合至 LibTPU SDK,因此安裝 LibTPU 時會一併安裝這項功能。
安裝 LibTPU
pip install libtpu
此外,LibTPU 更新會與 JAX 版本同步,也就是說,安裝最新 JAX 版本 (每月發布) 時,通常會將您固定在最新相容的 LibTPU 版本及其功能。
安裝 JAX
pip install -U "jax[tpu]"
PyTorch 使用者安裝 PyTorch/XLA 後,即可取得最新的 LibTPU 和 TPU 監控功能。
安裝 PyTorch/XLA
pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' \
-f https://storage.googleapis.com/libtpu-releases/index.html \
-f https://storage.googleapis.com/libtpu-wheels/index.html
# Optional: if you're using custom kernels, install pallas dependencies
pip install 'torch_xla[pallas]' \
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
如要進一步瞭解如何安裝 PyTorch/XLA,請參閱 PyTorch/XLA GitHub 存放區中的「安裝」一節。
在 Python 中匯入程式庫
如要開始使用 TPU 監控程式庫,您需要在 Python 程式碼中匯入 libtpu 模組。
from libtpu.sdk import tpumonitoring
列出所有支援的功能
列出所有指標名稱和支援的功能:
from libtpu.sdk import tpumonitoring
tpumonitoring.help()
" libtpu.sdk.monitoring.help():
List all supported functionality.
libtpu.sdk.monitoring.list_support_metrics()
List support metric names in the list of str format.
libtpu.sdk.monitoring.get_metric(metric_name:str)
Get metric data with metric name. It represents the snapshot mode.
The metric data is a object with `description()` and `data()` methods,
where the `description()` returns a string describe the format of data
and data unit, `data()` returns the metric data in the list in str format.
"
支援的指標
下列程式碼範例說明如何列出所有支援的指標名稱:
from libtpu.sdk import tpumonitoring
tpumonitoring.list_supported_metrics()
["duty_cycle_pct", "tensorcore_util", "hbm_util", ...]
下表列出所有指標及其對應定義:
| 指標 | 定義 | API 的指標名稱 | 範例值 |
|---|---|---|---|
| Tensor Core 使用率 | 測量 TensorCore 的用量百分比,計算方式為 TensorCore 作業的占比。每 1 秒取樣 10 微秒。取樣率無法修改。 您可以透過這項指標,監控 TPU 裝置中的工作負載效率。 |
tensorcore_util
|
['1.11', '2.22', '3.33', '4.44']
# utilization percentage for accelerator ID 0-3. |
| 工作週期百分比 | 過往取樣期間 (每 5 秒;可透過設定 LIBTPU_INIT_ARG 標記調整) 內,加速器積極處理作業的時間占比 (按照上一個取樣期間內執行 HLO 程式的週期記錄)。這項指標代表 TPU 的忙碌程度,並會針對每個晶片發出。 |
duty_cycle_pct
|
['10.00', '20.00', '30.00', '40.00']
# Duty cycle percentage for accelerator ID 0-3. |
| HBM 容量總計 | 這項指標會以位元組為單位,回報 HBM 總容量。 |
hbm_capacity_total
|
['30000000000', '30000000000', '30000000000', '30000000000']
# Total HBM capacity in bytes that attached to accelerator ID 0-3. |
| HBM 容量用量 | 這項指標會回報過去取樣期間 (每 5 秒;可透過設定 LIBTPU_INIT_ARG 旗標調整) 的 HBM 容量用量 (以位元組為單位)。 |
hbm_capacity_usage
|
['100', '200', '300', '400']
# Capacity usage for HBM in bytes that attached to accelerator ID 0-3. |
| 緩衝區傳輸延遲 | 巨型多切片流量的網路傳輸延遲。 這項視覺化功能可協助您瞭解整體網路效能環境。 |
buffer_transfer_latency
|
["'8MB+', '2233.25', '2182.02', '3761.93', '19277.01', '53553.6'"]
# buffer size, mean, p50, p90, p99, p99.9 of network transfer latency distribution. |
| 高階作業執行時間分配指標 | 提供 HLO 編譯二進位檔執行狀態的精細效能洞察資料,可偵測迴歸並進行模型層級的偵錯。 |
hlo_exec_timing
|
["'tensorcore-0', '10.00', '10.00', '20.00', '30.00', '40.00'"]
# The HLO execution time duration distribution for CoreType-CoreID with mean, p50, p90, p95, p999. |
| 高階最佳化工具佇列大小 | HLO 執行佇列大小監控會追蹤等待或正在執行的已編譯 HLO 程式數量。這項指標會顯示執行管道壅塞情形,有助於找出硬體執行、驅動程式負擔或資源分配方面的效能瓶頸。 |
hlo_queue_size
|
["tensorcore-0: 1", "tensorcore-1: 2"]
# Measures queue size for CoreType-CoreID. |
| 集體端對端延遲時間 | 這項指標會測量 DCN 的端對端集體延遲時間 (以微秒為單位),從主機啟動作業到所有對等互連接收輸出內容。包括主機端資料縮減,以及將輸出內容傳送至 TPU。結果是字串,詳細說明緩衝區大小、類型,以及平均、第 50、90、95 和 99.9 個百分位數的延遲時間。 |
collective_e2e_latency
|
["8MB+-ALL_REDUCE, 1000, 2000, 3000, 4000, 5000", …]
# Transfer size-collective op, mean, p50, p90, p95, p999 of collective end to end latency. |
| 傳輸層的往返延遲時間 | gRPC 用於多切片 TPU 流量的 TCP 連線所觀察到的最短往返時間 (RTT) 分佈。 |
grpc_tcp_min_round_trip_times
|
['27.63, 29.03, 38.52, 41.63, 52.74']
# Represents the distribution's mean, p50, p90, p95, and p99.9 percentiles in microseconds (µs). |
| 傳輸層的處理量 | gRPC 用於多重切片 TPU 流量的 TCP 連線,其近期總處理量的累積分布。 |
grpc_tcp_delivery_rates
|
['11354.89, 10986.35, 18239.55, 25718.35, 44841.55']
# Represents the distribution's mean, p50, p90, p95, and p99.9 percentiles in microseconds (µs). |
讀取指標資料
如要讀取指標資料,請在呼叫 tpumonitoring.get_metric 函式時指定指標名稱。您可以在效能不佳的程式碼中插入臨時指標檢查,判斷效能問題是源自軟體還是硬體。
以下程式碼範例說明如何讀取 duty_cycle 指標:
from libtpu.sdk import tpumonitoring
metric = tpumonitoring.get_metric("duty_cycle_pct")
metric.description()
"The metric provides a list of duty cycle percentages, one for each
accelerator (from accelerator_0 to accelerator_x). The duty cycle represents
the percentage of time an accelerator was actively processing during the
last sample period, indicating TPU utilization."
metric.data()
["0.00", "0.00", "0.00", "0.00"]
# accelerator_0-3
使用指標檢查 TPU 使用率
下列範例說明如何使用 TPU 監控程式庫中的指標,追蹤 TPU 使用率。
在 JAX 訓練期間監控 TPU 任務週期
情境:您正在執行 JAX 訓練指令碼,並想在整個訓練過程中監控 TPU 的 duty_cycle_pct 指標,確認 TPU 得到有效運用。您可以在訓練期間定期記錄這項指標,追蹤 TPU 使用率。
下列程式碼範例說明如何在 JAX 訓練期間監控 TPU 負載週期:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
import time
# --- Your JAX model and training setup would go here ---
# --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
return jnp.sum(x)
def loss_fn(params, x, y):
preds = simple_model(x)
return jnp.mean((preds - y)**2)
def train_step(params, x, y, optimizer):
grads = jax.grad(loss_fn)(params, x, y)
return optimizer.update(grads, params)
key = jax.random.PRNGKey(0)
params = jnp.array([1.0, 2.0]) # Example params
optimizer = ... # Your optimizer (for example, optax.adam)
data_x = jnp.ones((10, 10))
data_y = jnp.zeros((10,))
num_epochs = 10
log_interval_steps = 2 # Log duty cycle every 2 steps
for epoch in range(num_epochs):
for step in range(5): # Example steps per epoch
params = train_step(params, data_x, data_y, optimizer)
if (step + 1) % log_interval_steps == 0:
# --- Integrate TPU Monitoring Library here to get duty_cycle ---
duty_cycle_metric = tpumonitoring.get_metric("duty_cycle_pct")
duty_cycle_data = duty_cycle_metric.data
print(f"Epoch {epoch+1}, Step {step+1}: TPU Duty Cycle Data:")
print(f" Description: {duty_cycle_metric.description}")
print(f" Data: {duty_cycle_data}")
# --- End TPU Monitoring Library Integration ---
# --- Rest of your training loop logic ---
time.sleep(0.1) # Simulate some computation
print("Training complete.")
執行 JAX 推論前,請先檢查 HBM 使用率
情境: 使用 JAX 模型執行推論前,請先檢查 TPU 的 HBM (高頻寬記憶體) 使用率,確認有足夠的可用記憶體,並在推論開始前取得基準測量結果。
# The following code sample shows how to check HBM utilization before JAX inference:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
# --- Your JAX model and inference setup would go here ---
# --- Example placeholder model (replace with your actual model loading/setup)---
def simple_model(x):
return jnp.sum(x)
key = jax.random.PRNGKey(0)
params = ... # Load your trained parameters
# Integrate the TPU Monitoring Library to get HBM utilization before inference
hbm_util_metric = tpumonitoring.get_metric("hbm_util")
hbm_util_data = hbm_util_metric.data
print("HBM Utilization Before Inference:")
print(f" Description: {hbm_util_metric.description}")
print(f" Data: {hbm_util_data}")
# End TPU Monitoring Library Integration
# Your Inference Logic
input_data = jnp.ones((1, 10)) # Example input
predictions = simple_model(input_data)
print("Inference Predictions:", predictions)
print("Inference complete.")
查看網路指標
情境: 您正在執行多主機和多配量工作負載,並想使用 SSH 連線至其中一個 GKE Pod 或 TPU,以便在工作負載執行時查看網路指標。這些指令也可以直接併入多主機工作負載。
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
# --- Your JAX model and training setup goes here ---
# --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
return jnp.sum(x)
# --- details here ---
# ==============================================================================
# Metric 1: TCP Delivery Rate
# ==============================================================================
# This metric reflects the delivery rate of the TCP connection (bytes delivered / elapsed time).
# The output is a list of strings representing latency statistics:
# [mean, p50, p90, p95, p99.9]
# Example: ['100.00', '200.00', '300.00', '400.00', '500.00'] (Values in Mbps)
# Get the metric object
delivery_rate_metric = tpumonitoring.get_metric(metric_name_rate)
# Print the description provided by the library
print("Description:", delivery_rate_metric.description())
# Print the actual data payload
print("Data:", delivery_rate_metric.data())
# ==============================================================================
# Metric 2: TCP Minimum Round Trip Time (RTT)
# ==============================================================================
# This metric reflects the minimum RTT measured between sending a TCP packet
# and receiving the acknowledgement.
# The output is a list of strings representing latency statistics:
# [mean, p50, p90, p95, p99.9]
# Example: ['100.00', '200.00', '300.00', '400.00', '500.00'] (Values in us - microseconds)
# Get the metric object
min_rtt_metric = tpumonitoring.get_metric(metric_name_rtt)
# Print the description provided by the library
print("Description:", min_rtt_metric.description())
# Print the actual data payload
print("Data:", min_rtt_metric.data())
TPU 指標的重新整理頻率
TPU 指標的重新整理頻率下限為一秒。 主機指標資料會以 1 Hz 的固定頻率匯出。這個匯出程序造成的延遲時間微不足道。LibTPU 的執行階段指標不受相同頻率限制。不過,為了保持一致性,這些指標也會以 1 Hz 或每秒 1 個樣本的頻率取樣。
TPU-Z 模組
TPU-Z 是 TPU 的遙測和偵錯設施。這項工具會提供附加至主機的所有 TPU 核心的詳細執行階段狀態資訊。這項功能由 tpuz 模組提供,該模組是 libtpu Python SDK 中 libtpu.sdk 模組的一部分。這個模組會提供每個核心狀態的快照。
TPU-Z 的主要用途是診斷分散式 TPU 工作負載中的停止或死結。您可以在主機上查詢 TPU-Z 服務,擷取每個核心的狀態,比較所有核心的程式計數器、HLO 位置和執行 ID,找出異常狀況。
在 libtpu.sdk 程式庫中使用 get_core_state_summary() 函式,顯示 TPU-Z 指標:
summary = sdk.tpuz.get_core_state_summary()
TPU-Z 指標的輸出內容會以字典形式提供。以下是單一核心的截斷範例:
{
"host_name": "my-tpu-host-vm",
"core_states": {
"1": {
"core_id": {
"global_core_id": 1,
"chip_id": 0,
"core_on_chip": {
"type": "TPU_CORE_TYPE_TENSOR_CORE",
"index": 1
}
},
"sequencer_info": [
{
"sequencer_type": "TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER",
"sequencer_index": 0,
"pc": 4490,
"program_id": 3274167277388825310,
"run_id": 3
}
],
"program_fingerprint": "b'\\xbefB\\xc6\\x1eb\\xc1#\\xd0...'",
"queued_program_info": [],
"error_message": ""
}
// ...
}
}
如要擷取每個核心的高階最佳化工具 (HLO) 相關資訊,請將 include_hlo_info 參數設為 True:
summary = sdk.tpuz.get_core_state_summary(include_hlo_info=True)
輸出內容包含其他 HLO 資訊:
"1": {
"core_id": {
"global_core_id": 1,
"chip_id": 0,
"core_on_chip": {
"type": "TPU_CORE_TYPE_TENSOR_CORE",
"index": 1
}
},
"sequencer_info": [
{
"sequencer_type": "TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER",
"sequencer_index": 0,
"pc": 17776,
"tag": 3,
"tracemark": 2147483646,
"program_id": 3230481660274331500,
"run_id": 81,
"hlo_location": "HLO: fusion.11; HLO computation: main.126_spmd",
"hlo_detailed_info": "[{\"details\":\"HloModule fusion.11, entry_computation_layout={(bf16[20>..."
}
],
"program_fingerprint": "\ufffdU\ufffd4j\u7c6e\ufffd\ufffd{\u0017\ufffd\ufffdHHV\ufffdD\ufffde\uff>",
"launch_id": 1394130914,
"queued_program_info": [
{
"run_id": 81,
"launch_id": 1394130914,
"program_fingerprint": "\ufffdU\ufffd4j\u7c6e\ufffd\ufffd{\u0017\ufffd\ufffdHHV\ufffdD\ufffde>"
}
]
}
TPU-Z 指標
get_core_state_summary 函式會以字典形式傳回 TPU-Z 指標,結構如下。
CurrentCoreStateSummary
CurrentCoreStateSummary 字典提供個別 TPU 核心狀態的詳細摘要。
| 欄位 | 類型 | 說明 |
|---|---|---|
core_id |
dictionary | TpuCoreIdentifier 字典,內含 TPU 核心的 ID 資訊。 |
sequencer_info |
字典清單 | SequencerInfo 字典清單,說明核心上每個定序器的狀態。 |
program_fingerprint |
位元組 | 在這個核心上執行的程式指紋。 |
launch_id |
整數 | 目前或最近一次執行的程式啟動 ID。 |
queued_program_info |
字典清單 | 已排入執行佇列的程式字典清單。QueuedProgramInfo |
error_message |
字串 | 這個核心的任何錯誤訊息。 |
TpuCoreIdentifier
TpuCoreIdentifier 字典提供 TPU 系統中核心的 ID 資訊。
| 欄位 | 類型 | 說明 |
|---|---|---|
global_core_id |
整數 | 核心的 ID。 |
chip_id |
整數 | 核心所屬晶片的 ID。 |
core_on_chip |
dictionary | TpuCoreOnChip 字典,說明核心的類型和在晶片上的索引。 |
TpuCoreOnChip
TpuCoreOnChip 字典包含特定晶片中核心屬性的相關資訊。
| 欄位 | 類型 | 說明 |
|---|---|---|
type |
字串 | TPU 核心類型。例如 TPU_CORE_TYPE_TENSOR_CORE。 |
index |
整數 | 晶片上核心的索引。 |
SequencerInfo
SequencerInfo 字典包含核心上單一定序器的狀態資訊。
| 欄位 | 類型 | 說明 |
|---|---|---|
sequencer_type |
字串 | 定序器類型。例如 TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER。 |
sequencer_index |
整數 | 定序器索引 (如有相同類型的多個定序器)。 |
pc |
整數 | 目前的程式計數器值。 |
program_id |
整數 | 與特定程式執行個體相關聯的 ID,該程式執行個體會啟動,以便在 TPU 核心上執行。 |
run_id |
整數 | 與在 TPU 核心上執行的特定程式執行個體相關聯的執行 ID。 |
hlo_location |
字串 | 最佳化工具位置資訊概覽。 |
hlo_detailed_info |
字串 | 高階最佳化工具的詳細資訊。 |
QueuedProgramInfo
QueuedProgramInfo 字典包含已排入佇列,準備在核心上執行的程式相關資訊。
| 欄位 | 類型 | 說明 |
|---|---|---|
run_id |
整數 | 已加入佇列的程式執行 ID。 |
launch_id |
整數 | 已加入佇列節目的啟動 ID。 |
program_fingerprint |
位元組 | 已加入佇列節目的指紋。 |
搭配 JAX 使用 TPU-Z
您可以使用 libtpu.sdk 程式庫,在 JAX 工作負載中存取 TPU-Z 指標。下列 Python 指令碼使用 JAX 進行高效能張量運算,同時在背景執行緒中使用 libtpu SDK,監控基礎 TPU 硬體的狀態和活動。
包括下列 Python 套件:
import jax
import jax.numpy as jnp
import time
import threading
from functools import partial
from libtpu import sdk
monitor_tpu_status 函式會使用背景執行緒,在主要應用程式執行 JAX 工作負載時,持續顯示 TPU 核心的運作狀態。這項工具可即時診斷問題。
def monitor_tpu_status():
"""Monitors TPU status in a background thread."""
while monitoring_active:
try:
summary = sdk.tpuz.get_core_state_summary(include_hlo_info=True)
if summary and 'core_states' in summary:
print(summary)
else:
print('WARNING: Call returned an empty or invalid summary.')
except RuntimeError as e:
print(f'FAIL: Error calling API: {e}')
except Exception as e:
print(f'FAIL: Unexpected error in monitor thread: {e}')
for _ in range(MONITORING_INTERVAL_SECONDS * 2):
if not monitoring_active:
break
time.sleep(0.5)
print('✅ Monitoring thread stopped.')
transformer_block 函式會實作 Transformer 架構的完整層,這是 LLM 的基礎建構區塊。
@partial(jax.jit, static_argnames=['num_heads'])
def transformer_block(params, x, num_heads=32):
"""A simplified but computationally intensive Transformer block."""
# Multi-head Self-Attention
qkv = jnp.dot(x, params['qkv_kernel'])
q, k, v = jnp.array_split(qkv, 3, axis=-1)
# Reshape for multi-head attention
q = q.reshape(q.shape[0], q.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
k = k.reshape(k.shape[0], k.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
v = v.reshape(v.shape[0], v.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
# Scaled dot-product attention
attention_scores = jnp.einsum('nhqd,nhkd->nhqk', q, k) / jnp.sqrt(q.shape[-1])
attention_weights = jax.nn.softmax(attention_scores, axis=-1)
attention_output = jnp.einsum('nhqk,nhvd->nhqd', attention_weights, v)
attention_output = attention_output.transpose(0, 2, 1, 3).reshape(x.shape)
attention_output = jnp.dot(attention_output, params['o_kernel'])
# Residual connection and Layer Normalization 1
h1 = x + attention_output
h1_norm = h1 - jnp.mean(h1, axis=-1, keepdims=True)
h1_norm = h1_norm / jnp.sqrt(
jnp.mean(jnp.square(h1_norm), axis=-1, keepdims=True) + 1e-5
)
# Feed-Forward Network
ffn_hidden = jax.nn.gelu(jnp.dot(h1_norm, params['ffn1_kernel']))
ffn_output = jnp.dot(ffn_hidden, params['ffn2_kernel'])
# Residual connection and Layer Normalization 2
h2 = h1_norm + ffn_output
h2_norm = h2 - jnp.mean(h2, axis=-1, keepdims=True)
h2_norm = h2_norm / jnp.sqrt(
jnp.mean(jnp.square(h2_norm), axis=-1, keepdims=True) + 1e-5
)
return h2_norm
main 函式會協調 JAX 運算設定、啟動背景 TPU 監控,並執行主要工作負載迴圈。
def main():
num_devices = jax.device_count()
print(f"Running on {num_devices} devices.")
batch_size = 128 * num_devices
seq_len = 512
embed_dim = 1024
ffn_dim = embed_dim * 4
key = jax.random.PRNGKey(0)
params = {
'qkv_kernel': jax.random.normal(
key, (embed_dim, embed_dim * 3), dtype=jnp.bfloat16
),
'o_kernel': jax.random.normal(
key, (embed_dim, embed_dim), dtype=jnp.bfloat16
),
'ffn1_kernel': jax.random.normal(
key, (embed_dim, ffn_dim), dtype=jnp.bfloat16
),
'ffn2_kernel': jax.random.normal(
key, (ffn_dim, embed_dim), dtype=jnp.bfloat16
),
}
input_data = jax.random.normal(
key, (batch_size, seq_len, embed_dim), dtype=jnp.bfloat16
)
input_data = jax.device_put(input_data)
monitor_thread = threading.Thread(target=monitor_tpu_status)
monitor_thread.start()
print("Starting JAX computation loop...")
start_time = time.time()
iterations = 0
while time.time() - start_time < JOB_DURATION_SECONDS:
result = transformer_block(params, input_data)
result.block_until_ready()
iterations += 1
print(f' -> Jax iteration {iterations} complete.', end='\r')
print(f"\nCompleted {iterations} iterations in {JOB_DURATION_SECONDS} seconds.")
global monitoring_active
monitoring_active = False
monitor_thread.join()
if __name__ == '__main__':
main()
疑難排解
本節提供疑難排解資訊,協助您找出並解決使用 TPU 監控程式庫時可能遇到的問題。
缺少功能或指標
如果無法查看某些功能或指標,最常見的原因是libtpu版本過舊。TPU 監控程式庫功能和指標會納入 libtpu 版本,舊版可能缺少新功能和指標。
檢查環境中執行的 libtpu 版本:
指令列:
pip show libtpu
Python:
import libtpu
print(libtpu.__version__)
如果不是使用 libtpu 的最新版本,請使用下列指令更新程式庫:
pip install --upgrade libtpu