TPU 监控库
借助直接基于基础软件层 LibTPU 构建的高级 TPU 监控功能,深入了解 Cloud TPU 硬件的性能和行为。虽然 LibTPU 包含用于与 TPU 交互的驱动程序、网络库、XLA 编译器和 TPU 运行时,但本文档的重点是 TPU 监控库。
TPU 监控库提供以下功能:
全面的可观测性:访问遥测 API 和指标套件,详细了解 TPU 的运行性能和具体行为。
诊断工具包:提供 SDK 和命令行界面 (CLI),旨在对 TPU 资源进行调试和深入的性能分析。
这些监控功能适合面向客户的顶级解决方案,为您提供有效优化 TPU 工作负载所需的基本工具。
借助 TPU 监控库,您可以详细了解机器学习工作负载在 TPU 硬件上的运行情况。它旨在帮助您了解 TPU 利用率、找出瓶颈并调试性能问题。与中断指标、goodput 指标和其他指标相比,它可为您提供更详细的信息。
开始使用 TPU 监控库
您可以轻松获取强有力的数据分析。 TPU 监控功能与 LibTPU SDK 集成,因此在安装 LibTPU 时,该功能也会一并安装。
安装 LibTPU
pip install libtpu
或者,LibTPU 更新与 JAX 版本同步,这意味着当您安装最新的 JAX 版本(每月发布)时,通常会将您固定到最新的兼容 LibTPU 版本及其功能。
安装 JAX
pip install -U "jax[tpu]"
对于 PyTorch 用户,安装 PyTorch/XLA 会提供最新的 LibTPU 和 TPU 监控功能。
安装 PyTorch/XLA
pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' \
-f https://storage.googleapis.com/libtpu-releases/index.html \
-f https://storage.googleapis.com/libtpu-wheels/index.html
# Optional: if you're using custom kernels, install pallas dependencies
pip install 'torch_xla[pallas]' \
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
如需详细了解如何安装 PyTorch/XLA,请参阅 PyTorch/XLA GitHub 仓库中的安装。
在 Python 中导入库
如需开始使用 TPU 监控库,您需要在 Python 代码中导入 libtpu 模块。
from libtpu.sdk import tpumonitoring
列出所有支持的功能
列出所有指标名称及其支持的功能:
from libtpu.sdk import tpumonitoring
tpumonitoring.help()
" libtpu.sdk.monitoring.help():
List all supported functionality.
libtpu.sdk.monitoring.list_support_metrics()
List support metric names in the list of str format.
libtpu.sdk.monitoring.get_metric(metric_name:str)
Get metric data with metric name. It represents the snapshot mode.
The metric data is a object with `description()` and `data()` methods,
where the `description()` returns a string describe the format of data
and data unit, `data()` returns the metric data in the list in str format.
"
支持的指标
以下代码示例展示了如何列出所有受支持的指标名称:
from libtpu.sdk import tpumonitoring
tpumonitoring.list_supported_metrics()
["duty_cycle_pct", "tensorcore_util", "hbm_util", ...]
下表展示了所有指标及其对应的定义:
| 指标 | 定义 | API 的指标名称 | 示例值 |
|---|---|---|---|
| Tensor Core 利用率 | 衡量 TensorCore 用量的百分比,以属于 TensorCore 操作一部分的操作百分比形式计算。每 1 秒采样 10 微秒。您无法修改采样率。借助此指标,您可以监控 TPU 设备上工作负载的效率。 |
tensorcore_util
|
['1.11', '2.22', '3.33', '4.44']
# 加速器 ID 0-3 的利用率百分比。 |
| 占空比百分比 | 加速器活跃处理(通过用于上一个采样周期内执行 HLO 程序的周期记录)的时间占过去的采样周期(每 5 秒;可通过设置 LIBTPU_INIT_ARG 标志进行调整)的百分比。此指标表示 TPU 的繁忙程度。该指标按芯片发出。
|
duty_cycle_pct
|
['10.00', '20.00', '30.00', '40.00']
# 加速器 ID 0-3 的占空比百分比。 |
| HBM 总容量 | 此指标报告 HBM 总容量(以字节为单位)。 |
hbm_capacity_total
|
['30000000000', '30000000000', '30000000000', '30000000000']
# 附加到加速器 ID 0-3 的 HBM 总容量(以字节为单位)。 |
| HBM 容量用量 | 此指标报告过去的采样周期(每 5 秒;可通过设置 LIBTPU_INIT_ARG 标志进行调整)的 HBM 容量用量(以字节为单位)。
|
hbm_capacity_usage
|
['100', '200', '300', '400']
# 附加到加速器 ID 0-3 的 HBM 容量用量(以字节为单位)。 |
| 缓冲区传输延迟时间 | 超大规模多切片流量的网络传输延迟时间。 这种可视化图表可让您了解整体网络性能环境。 |
buffer_transfer_latency
|
["'8MB+', '2233.25', '2182.02', '3761.93', '19277.01', '53553.6'"]
# 缓冲区空间,网络传输延迟时间分布的平均值、p50、p90、p99、p99.9。 |
| 操作执行时间分布概要指标 | 针对 HLO 编译二进制文件执行状态提供详细的性能数据分析,以便进行回归检测和模型级调试。 |
hlo_exec_timing
|
["'tensorcore-0', '10.00', '10.00', '20.00', '30.00', '40.00'"]
# CoreType-CoreID 的 HLO 执行时间时长分布,包含平均值、p50、p90、p95、p999。 |
| 高级优化器队列大小 | HLO 执行队列大小监控会跟踪正在等待执行或正在执行的已编译 HLO 程序的数量。此指标展示了执行流水线拥塞情况,从而能够找出硬件执行、驱动程序开销或资源分配中的性能瓶颈。 |
hlo_queue_size
|
["tensorcore-0: 1", "tensorcore-1: 2"]
# 衡量 CoreType-CoreID 的队列大小。 |
| 总体端到端延迟时间 | 此指标用于衡量 DCN 上从发起操作的主机到接收输出的所有对等方的端到端集体延迟时间(以微秒为单位)。它包括主机端数据缩减和向 TPU 发送输出。结果是字符串,详细说明了缓冲区空间、类型以及平均延迟时间、p50、p90、p95 和 p99.9 延迟时间。 |
collective_e2e_latency
|
["8MB+-ALL_REDUCE, 1000, 2000, 3000, 4000, 5000", …]
# 传输大小-集合操作,集合端到端延迟时间的平均值、p50、p90、p95、p999。 |
| 传输层的往返延迟时间 | gRPC 用于多切片 TPU 流量的 TCP 连接上观测到的最短往返时间 (RTT) 的分布。 |
grpc_tcp_min_round_trip_times
|
['27.63, 29.03, 38.52, 41.63, 52.74']
# 表示分布的平均值、p50、p90、p95 和 p99.9 百分位,以微秒 (µs) 为单位。 |
| 传输层吞吐量 | gRPC 用于多切片 TPU 流量的 TCP 连接的近期吞吐量的累积分布。 |
grpc_tcp_delivery_rates
|
['11354.89, 10986.35, 18239.55, 25718.35, 44841.55']
# 表示分布的平均值、p50、p90、p95 和 p99.9 百分位,以微秒 (µs) 为单位。 |
读取指标数据
如需读取指标数据,请在调用 tpumonitoring.get_metric 函数时指定指标名称。您可以将临时指标检查插入到低性能代码中,以确定性能问题是源自软件还是硬件。
以下代码示例展示了如何读取 duty_cycle 指标:
from libtpu.sdk import tpumonitoring
metric = tpumonitoring.get_metric("duty_cycle_pct")
metric.description()
"The metric provides a list of duty cycle percentages, one for each
accelerator (from accelerator_0 to accelerator_x). The duty cycle represents
the percentage of time an accelerator was actively processing during the
last sample period, indicating TPU utilization."
metric.data()
["0.00", "0.00", "0.00", "0.00"]
# accelerator_0-3
使用指标检查 TPU 利用率
以下示例展示了如何使用 TPU 监控库中的指标来跟踪 TPU 利用率。
在 JAX 训练期间监控 TPU 占空比
场景:您正在运行 JAX 训练脚本,并希望在整个训练过程中监控 TPU 的 duty_cycle_pct 指标,以确认 TPU 是否得到了有效利用。您可以在训练期间定期记录此指标,以跟踪 TPU 利用率。
以下代码示例展示了如何在 JAX 训练期间监控 TPU 占空比:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
import time
# --- Your JAX model and training setup would go here ---
# --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
return jnp.sum(x)
def loss_fn(params, x, y):
preds = simple_model(x)
return jnp.mean((preds - y)**2)
def train_step(params, x, y, optimizer):
grads = jax.grad(loss_fn)(params, x, y)
return optimizer.update(grads, params)
key = jax.random.PRNGKey(0)
params = jnp.array([1.0, 2.0]) # Example params
optimizer = ... # Your optimizer (for example, optax.adam)
data_x = jnp.ones((10, 10))
data_y = jnp.zeros((10,))
num_epochs = 10
log_interval_steps = 2 # Log duty cycle every 2 steps
for epoch in range(num_epochs):
for step in range(5): # Example steps per epoch
params = train_step(params, data_x, data_y, optimizer)
if (step + 1) % log_interval_steps == 0:
# --- Integrate TPU Monitoring Library here to get duty_cycle ---
duty_cycle_metric = tpumonitoring.get_metric("duty_cycle_pct")
duty_cycle_data = duty_cycle_metric.data
print(f"Epoch {epoch+1}, Step {step+1}: TPU Duty Cycle Data:")
print(f" Description: {duty_cycle_metric.description}")
print(f" Data: {duty_cycle_data}")
# --- End TPU Monitoring Library Integration ---
# --- Rest of your training loop logic ---
time.sleep(0.1) # Simulate some computation
print("Training complete.")
在运行 JAX 推理之前检查 HBM 利用率
场景:在使用 JAX 模型运行推理之前,请检查 TPU 上的当前 HBM(高带宽内存)利用率,以确认您有足够的可用内存,并在推理开始前获取基准测量结果。
# The following code sample shows how to check HBM utilization before JAX inference:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
# --- Your JAX model and inference setup would go here ---
# --- Example placeholder model (replace with your actual model loading/setup)---
def simple_model(x):
return jnp.sum(x)
key = jax.random.PRNGKey(0)
params = ... # Load your trained parameters
# Integrate the TPU Monitoring Library to get HBM utilization before inference
hbm_util_metric = tpumonitoring.get_metric("hbm_util")
hbm_util_data = hbm_util_metric.data
print("HBM Utilization Before Inference:")
print(f" Description: {hbm_util_metric.description}")
print(f" Data: {hbm_util_data}")
# End TPU Monitoring Library Integration
# Your Inference Logic
input_data = jnp.ones((1, 10)) # Example input
predictions = simple_model(input_data)
print("Inference Predictions:", predictions)
print("Inference complete.")
查看网络指标
场景:您正在运行多主机和多切片工作负载,并希望使用 SSH 连接到其中一个 GKE Pod 或 TPU,以便在工作负载运行时查看网络指标。这些命令也可以直接纳入多主机工作负载中。
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
# --- Your JAX model and training setup goes here ---
# --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
return jnp.sum(x)
# --- details here ---
# ==============================================================================
# Metric 1: TCP Delivery Rate
# ==============================================================================
# This metric reflects the delivery rate of the TCP connection (bytes delivered / elapsed time).
# The output is a list of strings representing latency statistics:
# [mean, p50, p90, p95, p99.9]
# Example: ['100.00', '200.00', '300.00', '400.00', '500.00'] (Values in Mbps)
# Get the metric object
delivery_rate_metric = tpumonitoring.get_metric(metric_name_rate)
# Print the description provided by the library
print("Description:", delivery_rate_metric.description())
# Print the actual data payload
print("Data:", delivery_rate_metric.data())
# ==============================================================================
# Metric 2: TCP Minimum Round Trip Time (RTT)
# ==============================================================================
# This metric reflects the minimum RTT measured between sending a TCP packet
# and receiving the acknowledgement.
# The output is a list of strings representing latency statistics:
# [mean, p50, p90, p95, p99.9]
# Example: ['100.00', '200.00', '300.00', '400.00', '500.00'] (Values in us - microseconds)
# Get the metric object
min_rtt_metric = tpumonitoring.get_metric(metric_name_rtt)
# Print the description provided by the library
print("Description:", min_rtt_metric.description())
# Print the actual data payload
print("Data:", min_rtt_metric.data())
TPU 指标的刷新频率
TPU 指标的刷新频率仅限于最短 1 秒。 主机指标数据以 1 Hz 的固定频率导出。此导出过程导致的延迟时间可以忽略不计。LibTPU 中的运行时指标不受相同的频率限制条件的约束。但是,为了保持一致性,这些指标的采样频率也为 1 Hz,即每秒 1 次采样。
TPU-Z 模块
TPU-Z 是 TPU 的遥测和调试工具。它可提供连接到主机的所有 TPU 核心的详细运行时状态信息。该功能通过 tpuz 模块(即 libtpu Python SDK 中 libtpu.sdk 模块的一部分)提供。该模块提供每个核心的状态快照。
TPU-Z 的主要应用场景是诊断分布式 TPU 工作负载中的挂起或死锁。您可以查询主机上的 TPU-Z 服务,以捕获每个核心的状态,比较所有核心的程序计数器、HLO 位置和运行 ID,以识别异常情况。
使用 libtpu.sdk 库中的 get_core_state_summary() 函数显示 TPU-Z 指标:
summary = sdk.tpuz.get_core_state_summary()
TPU-Z 指标的输出以字典形式提供。以下是截断的单个核心示例:
{
"host_name": "my-tpu-host-vm",
"core_states": {
"1": {
"core_id": {
"global_core_id": 1,
"chip_id": 0,
"core_on_chip": {
"type": "TPU_CORE_TYPE_TENSOR_CORE",
"index": 1
}
},
"sequencer_info": [
{
"sequencer_type": "TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER",
"sequencer_index": 0,
"pc": 4490,
"program_id": 3274167277388825310,
"run_id": 3
}
],
"program_fingerprint": "b'\\xbefB\\xc6\\x1eb\\xc1#\\xd0...'",
"queued_program_info": [],
"error_message": ""
}
// ...
}
}
如需检索有关每个核心上的高级优化器 (HLO) 的信息,请将 include_hlo_info 参数设置为 True:
summary = sdk.tpuz.get_core_state_summary(include_hlo_info=True)
输出包含其他 HLO 信息:
"1": {
"core_id": {
"global_core_id": 1,
"chip_id": 0,
"core_on_chip": {
"type": "TPU_CORE_TYPE_TENSOR_CORE",
"index": 1
}
},
"sequencer_info": [
{
"sequencer_type": "TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER",
"sequencer_index": 0,
"pc": 17776,
"tag": 3,
"tracemark": 2147483646,
"program_id": 3230481660274331500,
"run_id": 81,
"hlo_location": "HLO: fusion.11; HLO computation: main.126_spmd",
"hlo_detailed_info": "[{\"details\":\"HloModule fusion.11, entry_computation_layout={(bf16[20>..."
}
],
"program_fingerprint": "\ufffdU\ufffd4j\u7c6e\ufffd\ufffd{\u0017\ufffd\ufffdHHV\ufffdD\ufffde\uff>",
"launch_id": 1394130914,
"queued_program_info": [
{
"run_id": 81,
"launch_id": 1394130914,
"program_fingerprint": "\ufffdU\ufffd4j\u7c6e\ufffd\ufffd{\u0017\ufffd\ufffdHHV\ufffdD\ufffde>"
}
]
}
TPU-Z 指标
get_core_state_summary 函数以具有以下结构的字典的形式返回 TPU-Z 指标。
CurrentCoreStateSummary
CurrentCoreStateSummary 字典提供单个 TPU 核心的状态的详细摘要。
| 字段 | 类型 | 说明 |
|---|---|---|
core_id |
字典 | 一个 TpuCoreIdentifier 字典,其中包含有关 TPU 核心的 ID 信息。 |
sequencer_info |
字典列表 | 一个 SequencerInfo 字典列表,描述核心上每个序列器的状态。 |
program_fingerprint |
字节 | 在相应核心上执行的程序的指纹。 |
launch_id |
整数 | 当前或最新程序的启动 ID。 |
queued_program_info |
字典列表 | 一个 QueuedProgramInfo 字典列表,用于存储已排队等待执行的程序。 |
error_message |
字符串 | 相应核心的任何错误消息。 |
TpuCoreIdentifier
TpuCoreIdentifier 字典提供 TPU 系统中的核心的 ID 信息。
| 字段 | 类型 | 说明 |
|---|---|---|
global_core_id |
整数 | 核心的 ID。 |
chip_id |
整数 | 核心所属芯片的 ID。 |
core_on_chip |
字典 | 一个 TpuCoreOnChip 字典,描述核心的类型及其在芯片上的索引。 |
TpuCoreOnChip
TpuCoreOnChip 字典包含有关特定芯片中的核心属性的信息。
| 字段 | 类型 | 说明 |
|---|---|---|
type |
字符串 | TPU 核心的类型。例如:TPU_CORE_TYPE_TENSOR_CORE。 |
index |
整数 | 芯片上核心的索引。 |
SequencerInfo
SequencerInfo 字典包含有关核心上单个序列器的状态的信息。
| 字段 | 类型 | 说明 |
|---|---|---|
sequencer_type |
字符串 | 序列器的类型。例如:TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER。 |
sequencer_index |
整数 | 序列器的索引(如果有多个相同类型的序列器)。 |
pc |
整数 | 当前程序计数器值。 |
program_id |
整数 | 与要启动以在 TPU 核心上执行的特定程序实例关联的 ID。 |
run_id |
整数 | 与在 TPU 核心上执行的特定程序实例关联的运行 ID。 |
hlo_location |
字符串 | 高级优化器的位置信息。 |
hlo_detailed_info |
字符串 | 详细的高级优化器信息。 |
QueuedProgramInfo
QueuedProgramInfo 字典包含有关已排队等待在核心上执行的程序的信息。
| 字段 | 类型 | 说明 |
|---|---|---|
run_id |
整数 | 已排队程序的运行 ID。 |
launch_id |
整数 | 已排队程序的启动 ID。 |
program_fingerprint |
字节 | 已排队程序的指纹。 |
将 TPU-Z 与 JAX 搭配使用
您可以通过 libtpu.sdk 库在 JAX 工作负载中访问 TPU-Z 指标。
以下 Python 脚本使用 JAX 进行高性能张量计算,同时在后台线程中使用 libtpu SDK 监控底层 TPU 硬件的状态和活动。
包括以下 Python 软件包:
import jax
import jax.numpy as jnp
import time
import threading
from functools import partial
from libtpu import sdk
monitor_tpu_status 函数使用后台线程持续显示 TPU 核心的运行状态,而主应用则执行 JAX 工作负载。它充当实时诊断工具。
def monitor_tpu_status():
"""Monitors TPU status in a background thread."""
while monitoring_active:
try:
summary = sdk.tpuz.get_core_state_summary(include_hlo_info=True)
if summary and 'core_states' in summary:
print(summary)
else:
print('WARNING: Call returned an empty or invalid summary.')
except RuntimeError as e:
print(f'FAIL: Error calling API: {e}')
except Exception as e:
print(f'FAIL: Unexpected error in monitor thread: {e}')
for _ in range(MONITORING_INTERVAL_SECONDS * 2):
if not monitoring_active:
break
time.sleep(0.5)
print('✅ Monitoring thread stopped.')
transformer_block 函数会实现 Transformer 架构的完整层,这是 LLM 的基础构建块。
@partial(jax.jit, static_argnames=['num_heads'])
def transformer_block(params, x, num_heads=32):
"""A simplified but computationally intensive Transformer block."""
# Multi-head Self-Attention
qkv = jnp.dot(x, params['qkv_kernel'])
q, k, v = jnp.array_split(qkv, 3, axis=-1)
# Reshape for multi-head attention
q = q.reshape(q.shape[0], q.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
k = k.reshape(k.shape[0], k.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
v = v.reshape(v.shape[0], v.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
# Scaled dot-product attention
attention_scores = jnp.einsum('nhqd,nhkd->nhqk', q, k) / jnp.sqrt(q.shape[-1])
attention_weights = jax.nn.softmax(attention_scores, axis=-1)
attention_output = jnp.einsum('nhqk,nhvd->nhqd', attention_weights, v)
attention_output = attention_output.transpose(0, 2, 1, 3).reshape(x.shape)
attention_output = jnp.dot(attention_output, params['o_kernel'])
# Residual connection and Layer Normalization 1
h1 = x + attention_output
h1_norm = h1 - jnp.mean(h1, axis=-1, keepdims=True)
h1_norm = h1_norm / jnp.sqrt(
jnp.mean(jnp.square(h1_norm), axis=-1, keepdims=True) + 1e-5
)
# Feed-Forward Network
ffn_hidden = jax.nn.gelu(jnp.dot(h1_norm, params['ffn1_kernel']))
ffn_output = jnp.dot(ffn_hidden, params['ffn2_kernel'])
# Residual connection and Layer Normalization 2
h2 = h1_norm + ffn_output
h2_norm = h2 - jnp.mean(h2, axis=-1, keepdims=True)
h2_norm = h2_norm / jnp.sqrt(
jnp.mean(jnp.square(h2_norm), axis=-1, keepdims=True) + 1e-5
)
return h2_norm
main 函数可编排 JAX 计算的设置、启动后台 TPU 监控,并运行主工作负载循环。
def main():
num_devices = jax.device_count()
print(f"Running on {num_devices} devices.")
batch_size = 128 * num_devices
seq_len = 512
embed_dim = 1024
ffn_dim = embed_dim * 4
key = jax.random.PRNGKey(0)
params = {
'qkv_kernel': jax.random.normal(
key, (embed_dim, embed_dim * 3), dtype=jnp.bfloat16
),
'o_kernel': jax.random.normal(
key, (embed_dim, embed_dim), dtype=jnp.bfloat16
),
'ffn1_kernel': jax.random.normal(
key, (embed_dim, ffn_dim), dtype=jnp.bfloat16
),
'ffn2_kernel': jax.random.normal(
key, (ffn_dim, embed_dim), dtype=jnp.bfloat16
),
}
input_data = jax.random.normal(
key, (batch_size, seq_len, embed_dim), dtype=jnp.bfloat16
)
input_data = jax.device_put(input_data)
monitor_thread = threading.Thread(target=monitor_tpu_status)
monitor_thread.start()
print("Starting JAX computation loop...")
start_time = time.time()
iterations = 0
while time.time() - start_time < JOB_DURATION_SECONDS:
result = transformer_block(params, input_data)
result.block_until_ready()
iterations += 1
print(f' -> Jax iteration {iterations} complete.', end='\r')
print(f"\nCompleted {iterations} iterations in {JOB_DURATION_SECONDS} seconds.")
global monitoring_active
monitoring_active = False
monitor_thread.join()
if __name__ == '__main__':
main()
问题排查
本部分提供的问题排查信息可帮助您识别和解决在使用 TPU 监控库时可能遇到的问题。
缺少功能或指标
如果您无法查看某些功能或指标,最常见的原因是 libtpu 版本过时。TPU 监控库功能和指标包含在 libtpu 版本中,过时的版本可能缺少新功能和指标。
检查环境中运行的 libtpu 的版本:
命令行:
pip show libtpu
Python:
import libtpu
print(libtpu.__version__)
如果您未使用 libtpu 的最新版本,请使用以下命令更新库:
pip install --upgrade libtpu