TPU 모니터링 라이브러리
기본 소프트웨어 계층 LibTPU 위에 직접 빌드된 고급 TPU 모니터링 기능을 통해 Cloud TPU 하드웨어의 성능과 동작을 심층적으로 분석할 수 있습니다. LibTPU는 TPU와 상호작용할 수 있도록 드라이버, 네트워킹 라이브러리, XLA 컴파일러, TPU 런타임 등을 포함하고 있지만 이 문서에서는 TPU 모니터링 라이브러리에 초점을 맞춰 설명합니다.
TPU 모니터링 라이브러리는 다음을 제공합니다.
포괄적인 모니터링 가능성: TPU의 운영 성능과 특정 동작에 대한 상세한 정보를 제공하는 원격 분석 API 및 측정항목 모음에 액세스합니다.
진단 툴킷: TPU 리소스를 디버깅하고 심층적인 성능 분석을 수행할 수 있도록 설계된 SDK 및 명령줄 인터페이스(CLI)를 제공합니다.
이러한 모니터링 기능은 고객을 대상으로 하는 최상위 솔루션으로 설계되어 TPU 워크로드를 효과적으로 최적화하는 핵심 도구를 제공합니다.
TPU 모니터링 라이브러리는 머신러닝 워크로드가 TPU 하드웨어에서 수행되는 방법에 대한 상세한 정보를 제공합니다. TPU 사용률을 파악하고 병목 현상을 식별하며 성능 문제를 디버깅할 수 있도록 설계되었습니다. 중단 측정항목, 굿풋 측정항목, 기타 측정항목보다 더 상세한 정보를 제공합니다.
TPU 모니터링 라이브러리 시작하기
이러한 고급 정보에 간단히 액세스할 수 있습니다. TPU 모니터링 기능은 LibTPU SDK에 통합되어 있으므로 LibTPU를 설치하면 함께 제공됩니다.
LibTPU 설치
pip install libtpu
또는 LibTPU 업데이트는 JAX 출시 버전과 연동되어 있습니다. 즉, 매달 출시되는 최신 JAX 출시 버전을 설치하면 일반적으로 최신 호환 LibTPU 버전과 그 기능이 함께 설치됩니다.
JAX 설치
pip install -U "jax[tpu]"
PyTorch 사용자의 경우 PyTorch/XLA를 설치하면 최신 LibTPU 및 TPU 모니터링 기능이 제공됩니다.
PyTorch/XLA 설치
pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' \
-f https://storage.googleapis.com/libtpu-releases/index.html \
-f https://storage.googleapis.com/libtpu-wheels/index.html
# Optional: if you're using custom kernels, install pallas dependencies
pip install 'torch_xla[pallas]' \
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
PyTorch/XLA 설치에 대한 자세한 내용은 PyTorch/XLA GitHub 저장소의 설치를 참조하세요.
Python에서 라이브러리 가져오기
TPU 모니터링 라이브러리를 사용하려면 Python 코드에서 libtpu 모듈을 가져와야 합니다.
from libtpu.sdk import tpumonitoring
지원되는 모든 기능 나열
지원되는 모든 측정항목 이름과 기능을 나열합니다.
from libtpu.sdk import tpumonitoring
tpumonitoring.help()
" libtpu.sdk.monitoring.help():
List all supported functionality.
libtpu.sdk.monitoring.list_support_metrics()
List support metric names in the list of str format.
libtpu.sdk.monitoring.get_metric(metric_name:str)
Get metric data with metric name. It represents the snapshot mode.
The metric data is a object with `description()` and `data()` methods,
where the `description()` returns a string describe the format of data
and data unit, `data()` returns the metric data in the list in str format.
"
지원되는 측정항목
다음 코드 샘플에서는 지원되는 모든 측정항목 이름을 나열하는 방법을 보여줍니다.
from libtpu.sdk import tpumonitoring
tpumonitoring.list_supported_metrics()
["duty_cycle_pct", "tensorcore_util", "hbm_util", ...]
다음 표에서는 모든 측정항목과 해당 정의를 보여줍니다.
| 측정항목 | 정의 | API의 측정항목 이름 | 예시 값 |
|---|---|---|---|
| TensorCore 사용률 | TensorCore 작업의 일부로 수행된 작업의 비율을 기준으로 계산된 TensorCore 사용량 비율입니다. 1초마다 10마이크로초씩 샘플링됩니다. 샘플링 레이트를 수정할 수 없습니다. 이 측정항목을 사용하면 TPU 기기에서 워크로드 효율성을 모니터링할 수 있습니다. |
tensorcore_util
|
['1.11', '2.22', '3.33', '4.44']
# 가속기 ID 0~3의 사용률 |
| 듀티 사이클 백분율 | 지난 샘플링 기간(5초 간격, LIBTPU_INIT_ARG 플래그 설정을 통해 조정 가능) 동안 가속기가 활발하게 처리 중이었던 시간의 비율로, 지난 샘플링 기간 동안 HLO 프로그램을 실행하는 데 사용된 사이클과 함께 기록됩니다. 이 측정항목은 TPU가 얼마나 많이 사용되고 있는지를 나타내며 칩 단위로 보고됩니다.
|
duty_cycle_pct
|
['10.00', '20.00', '30.00', '40.00']
# 가속기 ID 0~3에 대한 듀티 사이클 비율 |
| HBM 용량 합계 | 이 측정항목은 총 HBM 용량(바이트)을 보고합니다. |
hbm_capacity_total
|
['30000000000', '30000000000', '30000000000', '30000000000']
# 가속기 ID 0~3에 연결된 총 HBM 용량(바이트) |
| HBM 용량 사용량 | 이 측정항목은 지난 샘플링 주기(5초 간격, LIBTPU_INIT_ARG 플래그를 설정하여 조정 가능) 동안의 HBM 용량 사용량(바이트)을 보고합니다.
|
hbm_capacity_usage
|
['100', '200', '300', '400']
# 가속기 ID 0~3에 연결된 HBM 용량 사용량(바이트) |
| 버퍼 전송 지연 시간 | 메가스케일 멀티슬라이스 트래픽에 대한 네트워크 전송 지연 시간입니다. 이 시각화를 통해 전체적인 네트워크 성능 환경을 파악할 수 있습니다. |
buffer_transfer_latency
|
["'8MB+', '2233.25', '2182.02', '3761.93', '19277.01', '53553.6'"]
# 네트워크 전송 지연 시간 분포에 대한 버퍼 크기, 평균, p50, p90, p99, p99.9 |
| 고수준 작업 실행 시간 분포 측정항목 | HLO로 컴파일된 바이너리 실행 상태에 대한 세부적인 성능 통계를 제공하여 회귀 감지 및 모델 수준 디버깅을 지원합니다. |
hlo_exec_timing
|
["'tensorcore-0', '10.00', '10.00', '20.00', '30.00', '40.00'"]
# CoreType-CoreID에 대한 HLO 실행 시간 분포(평균, p50, p90, p95, p999) |
| 고수준 옵티마이저 큐 크기 | HLO 실행 큐 크기 모니터링은 실행을 대기 중이거나 실행 중인 컴파일된 HLO 프로그램 수를 추적합니다. 이 측정항목은 실행 파이프라인 혼잡도를 보여주며 하드웨어 실행, 드라이버 오버헤드 또는 리소스 할당과 관련된 성능 병목 현상 식별을 지원합니다. |
hlo_queue_size
|
["tensorcore-0: 1", "tensorcore-1: 2"]
# CoreType-CoreID의 큐 크기 측정 |
| 집합적 엔드 투 엔드 지연 시간 | 이 측정항목은 작업을 시작하는 호스트에서 출력을 수신하는 모든 피어까지 DCN을 통한 엔드 투 엔드 집합적 지연 시간(마이크로초)을 측정합니다. 여기에는 호스트 측 데이터 감소와 TPU에 출력 전송이 포함됩니다. 결과는 버퍼 크기, 유형, 평균, p50, p90, p95, p99.9 지연 시간을 자세히 설명하는 문자열입니다. |
collective_e2e_latency
|
["8MB+-ALL_REDUCE, 1000, 2000, 3000, 4000, 5000", …]
# 집합적 엔드 투 엔드 지연 시간의 전송 크기-집합 작업, 평균, p50, p90, p95, p999 |
| 전송 계층에서의 왕복 지연 시간 | gRPC가 멀티슬라이스 TPU 트래픽에 사용하는 TCP 연결에서 관찰된 최소 왕복 시간(RTT) 분포입니다. |
grpc_tcp_min_round_trip_times
|
['27.63, 29.03, 38.52, 41.63, 52.74']
# 분포의 평균, p50, p90, p95, p99.9 백분위수(마이크로초(µs))를 나타냅니다. |
| 전송 계층의 처리량 | gRPC에서 멀티슬라이스 TPU 트래픽에 사용하는 TCP 연결의 최근 처리량 누적 분포입니다. |
grpc_tcp_delivery_rates
|
['11354.89, 10986.35, 18239.55, 25718.35, 44841.55']
# 분포의 평균, p50, p90, p95, p99.9 백분위수(마이크로초(µs))를 나타냅니다. |
측정항목 데이터 읽기
측정항목 데이터를 읽으려면 tpumonitoring.get_metric 함수를 호출할 때 측정항목 이름을 지정합니다. 임시 측정항목 검사를 저성능 코드에 삽입하여 성능 문제가 소프트웨어 또는 하드웨어에서 비롯되었는지를 식별할 수 있습니다.
다음 코드 샘플에서는 duty_cycle 측정항목을 읽는 방법을 보여줍니다.
from libtpu.sdk import tpumonitoring
metric = tpumonitoring.get_metric("duty_cycle_pct")
metric.description()
"The metric provides a list of duty cycle percentages, one for each
accelerator (from accelerator_0 to accelerator_x). The duty cycle represents
the percentage of time an accelerator was actively processing during the
last sample period, indicating TPU utilization."
metric.data()
["0.00", "0.00", "0.00", "0.00"]
# accelerator_0-3
측정항목을 사용하여 TPU 사용률 확인
다음 예시에서는 TPU 모니터링 라이브러리의 측정항목을 사용하여 TPU 사용률을 추적하는 방법을 보여줍니다.
JAX 학습 중 TPU 듀티 사이클 모니터링
시나리오: JAX 학습 스크립트가 실행 중이며 학습 프로세스 전반에 걸쳐 TPU의 duty_cycle_pct 측정항목을 모니터링하여 TPU가 효과적으로 활용되고 있는지 확인하려고 합니다. 학습 중에 이 측정항목을 주기적으로 로깅하여 TPU 사용률을 추적할 수 있습니다.
다음 코드 샘플에서는 JAX 학습 중에 TPU 듀티 사이클을 모니터링하는 방법을 보여줍니다.
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
import time
# --- Your JAX model and training setup would go here ---
# --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
return jnp.sum(x)
def loss_fn(params, x, y):
preds = simple_model(x)
return jnp.mean((preds - y)**2)
def train_step(params, x, y, optimizer):
grads = jax.grad(loss_fn)(params, x, y)
return optimizer.update(grads, params)
key = jax.random.PRNGKey(0)
params = jnp.array([1.0, 2.0]) # Example params
optimizer = ... # Your optimizer (for example, optax.adam)
data_x = jnp.ones((10, 10))
data_y = jnp.zeros((10,))
num_epochs = 10
log_interval_steps = 2 # Log duty cycle every 2 steps
for epoch in range(num_epochs):
for step in range(5): # Example steps per epoch
params = train_step(params, data_x, data_y, optimizer)
if (step + 1) % log_interval_steps == 0:
# --- Integrate TPU Monitoring Library here to get duty_cycle ---
duty_cycle_metric = tpumonitoring.get_metric("duty_cycle_pct")
duty_cycle_data = duty_cycle_metric.data
print(f"Epoch {epoch+1}, Step {step+1}: TPU Duty Cycle Data:")
print(f" Description: {duty_cycle_metric.description}")
print(f" Data: {duty_cycle_data}")
# --- End TPU Monitoring Library Integration ---
# --- Rest of your training loop logic ---
time.sleep(0.1) # Simulate some computation
print("Training complete.")
JAX 추론 실행 전에 HBM 사용률 확인
시나리오: JAX 모델로 추론을 실행하기 전에 TPU의 현재 HBM(고대역 메모리) 사용률을 확인하여 사용 가능한 메모리가 충분한지 확인하고 추론 시작 전에 기준선을 측정합니다.
# The following code sample shows how to check HBM utilization before JAX inference:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
# --- Your JAX model and inference setup would go here ---
# --- Example placeholder model (replace with your actual model loading/setup)---
def simple_model(x):
return jnp.sum(x)
key = jax.random.PRNGKey(0)
params = ... # Load your trained parameters
# Integrate the TPU Monitoring Library to get HBM utilization before inference
hbm_util_metric = tpumonitoring.get_metric("hbm_util")
hbm_util_data = hbm_util_metric.data
print("HBM Utilization Before Inference:")
print(f" Description: {hbm_util_metric.description}")
print(f" Data: {hbm_util_data}")
# End TPU Monitoring Library Integration
# Your Inference Logic
input_data = jnp.ones((1, 10)) # Example input
predictions = simple_model(input_data)
print("Inference Predictions:", predictions)
print("Inference complete.")
네트워크 측정항목 확인
시나리오: 멀티 호스트 및 멀티슬라이스 워크로드가 실행 중이며 워크로드가 실행되는 동안 ssh를 사용하여 GKE 포드나 TPU 중 하나에 연결하여 네트워크 측정항목을 보려고 합니다. 명령어를 멀티 호스트 워크로드에 직접 통합할 수도 있습니다.
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
# --- Your JAX model and training setup goes here ---
# --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
return jnp.sum(x)
# --- details here ---
# ==============================================================================
# Metric 1: TCP Delivery Rate
# ==============================================================================
# This metric reflects the delivery rate of the TCP connection (bytes delivered / elapsed time).
# The output is a list of strings representing latency statistics:
# [mean, p50, p90, p95, p99.9]
# Example: ['100.00', '200.00', '300.00', '400.00', '500.00'] (Values in Mbps)
# Get the metric object
delivery_rate_metric = tpumonitoring.get_metric(metric_name_rate)
# Print the description provided by the library
print("Description:", delivery_rate_metric.description())
# Print the actual data payload
print("Data:", delivery_rate_metric.data())
# ==============================================================================
# Metric 2: TCP Minimum Round Trip Time (RTT)
# ==============================================================================
# This metric reflects the minimum RTT measured between sending a TCP packet
# and receiving the acknowledgement.
# The output is a list of strings representing latency statistics:
# [mean, p50, p90, p95, p99.9]
# Example: ['100.00', '200.00', '300.00', '400.00', '500.00'] (Values in us - microseconds)
# Get the metric object
min_rtt_metric = tpumonitoring.get_metric(metric_name_rtt)
# Print the description provided by the library
print("Description:", min_rtt_metric.description())
# Print the actual data payload
print("Data:", min_rtt_metric.data())
TPU 측정항목의 새로고침 빈도
TPU 측정항목의 새로고침 빈도는 최소 1초로 제한됩니다. 호스트 측정항목 데이터를 고정 빈도(1Hz)로 내보냅니다. 이 내보내기 프로세스에서 발생하는 지연 시간을 무시할 수 있습니다. LibTPU에서 제공되는 런타임 측정항목은 이러한 빈도 제약 조건에 적용되지 않습니다. 그러나 일관성을 위해 해당 측정항목도 1Hz에 또는 초당 샘플 1개로 샘플링됩니다.
TPU-Z 모듈
TPU-Z는 TPU의 원격 분석 및 디버깅 기능입니다. 호스트에 연결된 모든 TPU 코어의 자세한 런타임 상태 정보를 제공합니다. 이 기능은 libtpu Python SDK의 libtpu.sdk 모듈에 포함된 tpuz 모듈을 통해 제공됩니다. 이 모듈은 각 코어 상태 스냅샷을 제공합니다.
TPU-Z의 기본 사용 사례는 분산 TPU 워크로드에서 중단 또는 교착 상태를 진단하는 것입니다. 호스트에서 TPU-Z 서비스를 쿼리하여 모든 코어의 상태를 캡처하고 모든 코어에서 프로그램 카운터, HLO 위치, 실행 ID를 비교하여 이상치를 식별할 수 있습니다.
libtpu.sdk 라이브러리 내에서 get_core_state_summary() 함수를 사용하여 TPU-Z 측정항목을 표시합니다.
summary = sdk.tpuz.get_core_state_summary()
TPU-Z 측정항목 출력은 딕셔너리로 제공됩니다. 다음은 단일 코어의 잘린 예시입니다.
{
"host_name": "my-tpu-host-vm",
"core_states": {
"1": {
"core_id": {
"global_core_id": 1,
"chip_id": 0,
"core_on_chip": {
"type": "TPU_CORE_TYPE_TENSOR_CORE",
"index": 1
}
},
"sequencer_info": [
{
"sequencer_type": "TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER",
"sequencer_index": 0,
"pc": 4490,
"program_id": 3274167277388825310,
"run_id": 3
}
],
"program_fingerprint": "b'\\xbefB\\xc6\\x1eb\\xc1#\\xd0...'",
"queued_program_info": [],
"error_message": ""
}
// ...
}
}
각 코어의 고수준 옵티마이저(HLO)에 대한 정보를 검색하려면 include_hlo_info 파라미터를 True로 설정합니다.
summary = sdk.tpuz.get_core_state_summary(include_hlo_info=True)
출력에는 추가 HLO 정보가 포함됩니다.
"1": {
"core_id": {
"global_core_id": 1,
"chip_id": 0,
"core_on_chip": {
"type": "TPU_CORE_TYPE_TENSOR_CORE",
"index": 1
}
},
"sequencer_info": [
{
"sequencer_type": "TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER",
"sequencer_index": 0,
"pc": 17776,
"tag": 3,
"tracemark": 2147483646,
"program_id": 3230481660274331500,
"run_id": 81,
"hlo_location": "HLO: fusion.11; HLO computation: main.126_spmd",
"hlo_detailed_info": "[{\"details\":\"HloModule fusion.11, entry_computation_layout={(bf16[20>..."
}
],
"program_fingerprint": "\ufffdU\ufffd4j\u7c6e\ufffd\ufffd{\u0017\ufffd\ufffdHHV\ufffdD\ufffde\uff>",
"launch_id": 1394130914,
"queued_program_info": [
{
"run_id": 81,
"launch_id": 1394130914,
"program_fingerprint": "\ufffdU\ufffd4j\u7c6e\ufffd\ufffd{\u0017\ufffd\ufffdHHV\ufffdD\ufffde>"
}
]
}
TPU-Z 측정항목
get_core_state_summary 함수는 다음 구조를 사용하여 TPU-Z 측정항목을 딕셔너리 양식으로 반환합니다.
CurrentCoreStateSummary
CurrentCoreStateSummary 딕셔너리는 개별 TPU 코어 상태에 대한 자세한 내용을 요약해서 제공합니다.
| 필드 | 유형 | 설명 |
|---|---|---|
core_id |
딕셔너리 | TPU 코어에 대한 ID 정보가 포함된 TpuCoreIdentifier 딕셔너리입니다. |
sequencer_info |
딕셔너리 목록 | 코어의 각 시퀀서 상태를 설명하는 SequencerInfo 딕셔너리 목록입니다. |
program_fingerprint |
바이트 | 이 코어에서 실행되는 프로그램의 지문입니다. |
launch_id |
정수 | 현재 또는 가장 최근 프로그램의 출시 ID입니다. |
queued_program_info |
딕셔너리 목록 | 실행 큐에 추가된 프로그램의 QueuedProgramInfo 딕셔너리 목록입니다. |
error_message |
문자열 | 이 코어의 오류 메시지입니다. |
TpuCoreIdentifier
TpuCoreIdentifier 딕셔너리는 TPU 시스템 내 코어의 ID 정보를 제공합니다.
| 필드 | 유형 | 설명 |
|---|---|---|
global_core_id |
정수 | 코어 ID입니다. |
chip_id |
정수 | 코어가 속한 칩의 ID입니다. |
core_on_chip |
딕셔너리 | 칩의 코어 유형과 색인을 설명하는 TpuCoreOnChip 딕셔너리입니다. |
TpuCoreOnChip
TpuCoreOnChip 딕셔너리에는 특정 칩에 있는 코어의 속성에 대한 정보가 포함되어 있습니다.
| 필드 | 유형 | 설명 |
|---|---|---|
type |
문자열 | TPU 코어 유형입니다. 예를 들면 TPU_CORE_TYPE_TENSOR_CORE입니다. |
index |
정수 | 칩의 코어 색인입니다. |
SequencerInfo
SequencerInfo 딕셔너리에는 코어의 단일 시퀀서 상태에 대한 정보가 포함됩니다.
| 필드 | 유형 | 설명 |
|---|---|---|
sequencer_type |
문자열 | 시퀀서 유형입니다. 예를 들면 TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER입니다. |
sequencer_index |
정수 | 시퀀서 색인입니다(같은 유형이 여러 개 있는 경우). |
pc |
정수 | 현재 프로그램 카운터 값입니다. |
program_id |
정수 | TPU 코어에서 실행하기 위해 출시되는 프로그램의 특정 인스턴스와 연결된 ID입니다. |
run_id |
정수 | TPU 코어에서 프로그램 실행 특정 인스턴스와 연결된 실행 ID입니다. |
hlo_location |
문자열 | 고수준 옵티마이저 위치 정보입니다. |
hlo_detailed_info |
문자열 | 자세한 고수준 옵티마이저 정보입니다. |
QueuedProgramInfo
QueuedProgramInfo 딕셔너리에는 코어에서 실행하기 위해 큐에 추가된 프로그램에 대한 정보가 포함되어 있습니다.
| 필드 | 유형 | 설명 |
|---|---|---|
run_id |
정수 | 큐에 추가된 프로그램의 실행 ID입니다. |
launch_id |
정수 | 큐에 추가된 프로그램의 출시 ID입니다. |
program_fingerprint |
바이트 | 큐에 추가된 프로그램의 지문입니다. |
JAX를 사용한 TPU-Z
libtpu.sdk 라이브러리를 통해 JAX 워크로드에서 TPU-Z 측정항목에 액세스할 수 있습니다.
다음 Python 스크립트는 고성능 텐서 계산에 JAX를 사용하는 동시에 백그라운드 스레드에서 libtpu SDK를 사용하여 기본 TPU 하드웨어의 상태와 활동을 모니터링합니다.
다음 Python 패키지를 포함합니다.
import jax
import jax.numpy as jnp
import time
import threading
from functools import partial
from libtpu import sdk
monitor_tpu_status 함수는 백그라운드 스레드를 사용하여 기본 애플리케이션에서 JAX 워크로드를 실행하는 동안 TPU 코어 작동 상태를 지속적으로 보여줍니다. 실시간 진단 도구 역할을 합니다.
def monitor_tpu_status():
"""Monitors TPU status in a background thread."""
while monitoring_active:
try:
summary = sdk.tpuz.get_core_state_summary(include_hlo_info=True)
if summary and 'core_states' in summary:
print(summary)
else:
print('WARNING: Call returned an empty or invalid summary.')
except RuntimeError as e:
print(f'FAIL: Error calling API: {e}')
except Exception as e:
print(f'FAIL: Unexpected error in monitor thread: {e}')
for _ in range(MONITORING_INTERVAL_SECONDS * 2):
if not monitoring_active:
break
time.sleep(0.5)
print('✅ Monitoring thread stopped.')
transformer_block 함수는 LLM의 기본 구성요소인 Transformer 아키텍처의 완전한 레이어를 구현합니다.
@partial(jax.jit, static_argnames=['num_heads'])
def transformer_block(params, x, num_heads=32):
"""A simplified but computationally intensive Transformer block."""
# Multi-head Self-Attention
qkv = jnp.dot(x, params['qkv_kernel'])
q, k, v = jnp.array_split(qkv, 3, axis=-1)
# Reshape for multi-head attention
q = q.reshape(q.shape[0], q.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
k = k.reshape(k.shape[0], k.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
v = v.reshape(v.shape[0], v.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
# Scaled dot-product attention
attention_scores = jnp.einsum('nhqd,nhkd->nhqk', q, k) / jnp.sqrt(q.shape[-1])
attention_weights = jax.nn.softmax(attention_scores, axis=-1)
attention_output = jnp.einsum('nhqk,nhvd->nhqd', attention_weights, v)
attention_output = attention_output.transpose(0, 2, 1, 3).reshape(x.shape)
attention_output = jnp.dot(attention_output, params['o_kernel'])
# Residual connection and Layer Normalization 1
h1 = x + attention_output
h1_norm = h1 - jnp.mean(h1, axis=-1, keepdims=True)
h1_norm = h1_norm / jnp.sqrt(
jnp.mean(jnp.square(h1_norm), axis=-1, keepdims=True) + 1e-5
)
# Feed-Forward Network
ffn_hidden = jax.nn.gelu(jnp.dot(h1_norm, params['ffn1_kernel']))
ffn_output = jnp.dot(ffn_hidden, params['ffn2_kernel'])
# Residual connection and Layer Normalization 2
h2 = h1_norm + ffn_output
h2_norm = h2 - jnp.mean(h2, axis=-1, keepdims=True)
h2_norm = h2_norm / jnp.sqrt(
jnp.mean(jnp.square(h2_norm), axis=-1, keepdims=True) + 1e-5
)
return h2_norm
main 함수는 JAX 계산 설정을 조정하고 백그라운드 TPU 모니터링을 시작하며 기본 워크로드 루프를 실행합니다.
def main():
num_devices = jax.device_count()
print(f"Running on {num_devices} devices.")
batch_size = 128 * num_devices
seq_len = 512
embed_dim = 1024
ffn_dim = embed_dim * 4
key = jax.random.PRNGKey(0)
params = {
'qkv_kernel': jax.random.normal(
key, (embed_dim, embed_dim * 3), dtype=jnp.bfloat16
),
'o_kernel': jax.random.normal(
key, (embed_dim, embed_dim), dtype=jnp.bfloat16
),
'ffn1_kernel': jax.random.normal(
key, (embed_dim, ffn_dim), dtype=jnp.bfloat16
),
'ffn2_kernel': jax.random.normal(
key, (ffn_dim, embed_dim), dtype=jnp.bfloat16
),
}
input_data = jax.random.normal(
key, (batch_size, seq_len, embed_dim), dtype=jnp.bfloat16
)
input_data = jax.device_put(input_data)
monitor_thread = threading.Thread(target=monitor_tpu_status)
monitor_thread.start()
print("Starting JAX computation loop...")
start_time = time.time()
iterations = 0
while time.time() - start_time < JOB_DURATION_SECONDS:
result = transformer_block(params, input_data)
result.block_until_ready()
iterations += 1
print(f' -> Jax iteration {iterations} complete.', end='\r')
print(f"\nCompleted {iterations} iterations in {JOB_DURATION_SECONDS} seconds.")
global monitoring_active
monitoring_active = False
monitor_thread.join()
if __name__ == '__main__':
main()
문제 해결
이 섹션에서는 TPU 모니터링 라이브러리를 사용하는 동안 발생할 수 있는 문제를 식별하고 해결하는 데 도움이 되는 문제 해결 정보를 제공합니다.
기능 또는 측정항목 누락
일부 기능이나 측정항목이 표시되지 않는 경우 가장 일반적인 원인은 오래된 libtpu 버전입니다. TPU 모니터링 라이브러리 기능과 측정항목은 libtpu 출시 버전에 포함되어 있으며 오래된 버전에는 새로운 기능과 측정항목이 누락되어 있을 수 있습니다.
환경에서 실행 중인 libtpu 버전을 확인합니다.
명령줄:
pip show libtpu
Python
import libtpu
print(libtpu.__version__)
libtpu의 최신 버전을 사용하고 있지 않으면 다음 명령어를 사용하여 라이브러리를 업데이트합니다.
pip install --upgrade libtpu