TPU Monitoring Library
Dapatkan insight mendalam tentang performa dan perilaku hardware Cloud TPU Anda dengan kemampuan pemantauan TPU tingkat lanjut, yang dibangun langsung di atas lapisan software dasar, LibTPU. Meskipun LibTPU mencakup driver, library jaringan, compiler XLA, dan runtime TPU untuk berinteraksi dengan TPU, dokumen ini berfokus pada TPU Monitoring Library.
TPU Monitoring Library menyediakan:
Observabilitas komprehensif: Dapatkan akses ke rangkaian metrik dan API telemetri, yang memberikan insight mendetail tentang performa operasional dan perilaku spesifik TPU Anda.
Toolkit diagnostik: Menyediakan SDK dan antarmuka command line (CLI) yang dirancang untuk memungkinkan proses debug dan analisis performa mendalam pada resource TPU Anda.
Fitur pemantauan ini dirancang sebagai solusi tingkat teratas yang ditujukan untuk pelanggan, sehingga memberi Anda alat penting untuk mengoptimalkan beban kerja TPU secara efektif.
TPU Monitoring Library memberi Anda informasi mendetail tentang performa workload machine learning pada hardware TPU. Alat ini dirancang untuk membantu Anda memahami penggunaan TPU, mengidentifikasi hambatan, dan men-debug masalah performa. Metrik ini memberi Anda informasi yang lebih mendetail daripada metrik gangguan, metrik goodput, dan metrik lainnya.
Mulai menggunakan TPU Monitoring Library
Mengakses insight yang efektif ini sangatlah mudah. Fungsi pemantauan TPU diintegrasikan dengan LibTPU SDK, sehingga fungsi tersebut disertakan saat Anda menginstal LibTPU.
Instal LibTPU
pip install libtpu
Atau, update LibTPU dikoordinasikan dengan rilis JAX, yang berarti bahwa saat Anda menginstal rilis JAX terbaru (dirilis setiap bulan), Anda biasanya akan menggunakan versi LibTPU yang kompatibel dan fiturnya yang terbaru.
Menginstal JAX
pip install -U "jax[tpu]"
Untuk pengguna PyTorch, menginstal PyTorch/XLA akan memberikan fungsi pemantauan TPU dan LibTPU terbaru.
Menginstal PyTorch/XLA
pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' \
-f https://storage.googleapis.com/libtpu-releases/index.html \
-f https://storage.googleapis.com/libtpu-wheels/index.html
# Optional: if you're using custom kernels, install pallas dependencies
pip install 'torch_xla[pallas]' \
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
Untuk mengetahui informasi selengkapnya tentang cara menginstal PyTorch/XLA, lihat Penginstalan di repositori GitHub PyTorch/XLA.
Mengimpor library di Python
Untuk mulai menggunakan TPU Monitoring Library, Anda perlu mengimpor modul libtpu
dalam kode Python Anda.
from libtpu.sdk import tpumonitoring
Mencantumkan semua fungsi yang didukung
Mencantumkan semua nama metrik dan fungsi yang didukungnya:
from libtpu.sdk import tpumonitoring
tpumonitoring.help()
" libtpu.sdk.monitoring.help():
List all supported functionality.
libtpu.sdk.monitoring.list_support_metrics()
List support metric names in the list of str format.
libtpu.sdk.monitoring.get_metric(metric_name:str)
Get metric data with metric name. It represents the snapshot mode.
The metric data is a object with `description()` and `data()` methods,
where the `description()` returns a string describe the format of data
and data unit, `data()` returns the metric data in the list in str format.
"
Metrik yang didukung
Contoh kode berikut menunjukkan cara mencantumkan semua nama metrik yang didukung:
from libtpu.sdk import tpumonitoring
tpumonitoring.list_supported_metrics()
["duty_cycle_pct", "tensorcore_util", "hbm_util", ...]
Tabel berikut menampilkan semua metrik dan definisi yang sesuai:
| Metrik | Definisi | Nama metrik untuk API | Contoh nilai |
|---|---|---|---|
| Penggunaan Tensor Core | Mengukur persentase penggunaan TensorCore Anda, yang dihitung sebagai persentase operasi yang merupakan bagian dari operasi TensorCore. Sampel diambil 10 mikrodetik setiap 1 detik. Anda tidak dapat mengubah frekuensi pengambilan sampel. Metrik ini memungkinkan Anda memantau efisiensi beban kerja di perangkat TPU. |
tensorcore_util
|
['1.11', '2.22', '3.33', '4.44']
# persentase pemanfaatan untuk ID akselerator 0-3. |
| Persentase Siklus Tugas | Persentase waktu selama periode sampel terakhir (setiap 5 detik; dapat disesuaikan dengan menetapkan flag LIBTPU_INIT_ARG) saat akselerator secara aktif memproses (direkam dengan siklus yang digunakan untuk mengeksekusi program HLO selama periode pengambilan sampel terakhir). Metrik ini
menunjukkan seberapa sibuk TPU. Metrik ini dikeluarkan per chip.
|
duty_cycle_pct
|
['10.00', '20.00', '30.00', '40.00']
# Persentase siklus tugas untuk ID akselerator 0-3. |
| Total Kapasitas HBM | Metrik ini melaporkan total kapasitas HBM dalam byte. |
hbm_capacity_total
|
['30000000000', '30000000000', '30000000000', '30000000000']
# Total kapasitas HBM dalam byte yang terpasang ke ID akselerator 0-3. |
| Penggunaan Kapasitas HBM | Metrik ini melaporkan penggunaan kapasitas HBM dalam byte selama
periode pengambilan sampel terakhir (setiap 5 detik; dapat disesuaikan dengan menyetel
flag LIBTPU_INIT_ARG).
|
hbm_capacity_usage
|
['100', '200', '300', '400']
# Penggunaan kapasitas untuk HBM dalam byte yang terhubung ke ID akselerator 0-3. |
| Latensi transfer buffer | Latensi transfer jaringan untuk traffic multislice berskala besar. Visualisasi ini memungkinkan Anda memahami lingkungan performa jaringan secara keseluruhan. |
buffer_transfer_latency
|
["'8MB+', '2233.25', '2182.02', '3761.93', '19277.01', '53553.6'"]
# ukuran buffer, rata-rata, p50, p90, p99, p99,9 dari distribusi latensi transfer jaringan. |
| Metrik Distribusi Waktu Eksekusi Operasi Tingkat Tinggi | Memberikan insight performa terperinci tentang status eksekusi biner yang dikompilasi HLO, sehingga memungkinkan deteksi regresi dan proses debug tingkat model. |
hlo_exec_timing
|
["'tensorcore-0', '10.00', '10.00', '20.00', '30.00', '40.00'"]
# Distribusi durasi waktu eksekusi HLO untuk CoreType-CoreID dengan mean, p50, p90, p95, p999. |
| Ukuran antrean Pengoptimal Tingkat Tinggi | Pemantauan ukuran antrean eksekusi HLO melacak jumlah program HLO yang dikompilasi yang menunggu atau sedang dieksekusi. Metrik ini mengungkapkan kemacetan pipeline eksekusi, sehingga memungkinkan identifikasi bottleneck performa dalam eksekusi hardware, overhead driver, atau alokasi resource. |
hlo_queue_size
|
["tensorcore-0: 1", "tensorcore-1: 2"]
# Mengukur ukuran antrean untuk CoreType-CoreID. |
| Latensi End-to-End Kolektif | Metrik ini mengukur latensi kolektif end-to-end melalui DCN dalam mikrodetik, dari host yang memulai operasi hingga semua peer yang menerima output. Hal ini mencakup pengurangan data sisi host dan pengiriman output ke TPU. Hasilnya adalah string yang menjelaskan ukuran buffer, jenis, dan latensi rata-rata, p50, p90, p95, dan p99,9. |
collective_e2e_latency
|
["8MB+-ALL_REDUCE, 1000, 2000, 3000, 4000, 5000", …]
# Operasi kolektif ukuran transfer, rata-rata, p50, p90, p95, p999 latensi end-to-end kolektif. |
| Latensi Round Trip di Lapisan Transport | Distribusi Round Trip Time (RTT) minimum yang diamati pada koneksi TCP yang digunakan oleh gRPC untuk traffic TPU multislice. |
grpc_tcp_min_round_trip_times
|
['27.63, 29.03, 38.52, 41.63, 52.74']
# Mewakili persentil rata-rata, p50, p90, p95, dan p99,9 distribusi dalam mikrodetik (µs). |
| Throughput di Lapisan Transport | Distribusi kumulatif throughput terbaru koneksi TCP yang digunakan oleh gRPC untuk traffic TPU multislice. |
grpc_tcp_delivery_rates
|
['11354.89, 10986.35, 18239.55, 25718.35, 44841.55']
# Mewakili persentil rata-rata, p50, p90, p95, dan p99,9 distribusi dalam mikrodetik (µs). |
Membaca data metrik
Untuk membaca data metrik, tentukan nama metrik saat Anda memanggil fungsi
tpumonitoring.get_metric. Anda dapat menyisipkan pemeriksaan metrik ad hoc ke dalam
kode berperforma rendah untuk mengidentifikasi apakah masalah performa berasal dari software
atau hardware.
Contoh kode berikut menunjukkan cara membaca metrik duty_cycle:
from libtpu.sdk import tpumonitoring
metric = tpumonitoring.get_metric("duty_cycle_pct")
metric.description()
"The metric provides a list of duty cycle percentages, one for each
accelerator (from accelerator_0 to accelerator_x). The duty cycle represents
the percentage of time an accelerator was actively processing during the
last sample period, indicating TPU utilization."
metric.data()
["0.00", "0.00", "0.00", "0.00"]
# accelerator_0-3
Menggunakan metrik untuk memeriksa pemanfaatan TPU
Contoh berikut menunjukkan cara menggunakan metrik dari TPU Monitoring Library untuk melacak pemanfaatan TPU.
Memantau siklus tugas TPU selama pelatihan JAX
Skenario: Anda menjalankan skrip pelatihan JAX dan ingin memantau metrik duty_cycle_pct TPU selama proses pelatihan untuk mengonfirmasi bahwa TPU Anda digunakan secara efektif. Anda dapat mencatat metrik ini secara berkala selama pelatihan untuk melacak pemakaian TPU.
Contoh kode berikut menunjukkan cara memantau Siklus Tugas TPU selama pelatihan JAX:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
import time
# --- Your JAX model and training setup would go here ---
# --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
return jnp.sum(x)
def loss_fn(params, x, y):
preds = simple_model(x)
return jnp.mean((preds - y)**2)
def train_step(params, x, y, optimizer):
grads = jax.grad(loss_fn)(params, x, y)
return optimizer.update(grads, params)
key = jax.random.PRNGKey(0)
params = jnp.array([1.0, 2.0]) # Example params
optimizer = ... # Your optimizer (for example, optax.adam)
data_x = jnp.ones((10, 10))
data_y = jnp.zeros((10,))
num_epochs = 10
log_interval_steps = 2 # Log duty cycle every 2 steps
for epoch in range(num_epochs):
for step in range(5): # Example steps per epoch
params = train_step(params, data_x, data_y, optimizer)
if (step + 1) % log_interval_steps == 0:
# --- Integrate TPU Monitoring Library here to get duty_cycle ---
duty_cycle_metric = tpumonitoring.get_metric("duty_cycle_pct")
duty_cycle_data = duty_cycle_metric.data
print(f"Epoch {epoch+1}, Step {step+1}: TPU Duty Cycle Data:")
print(f" Description: {duty_cycle_metric.description}")
print(f" Data: {duty_cycle_data}")
# --- End TPU Monitoring Library Integration ---
# --- Rest of your training loop logic ---
time.sleep(0.1) # Simulate some computation
print("Training complete.")
Memeriksa pemanfaatan HBM sebelum menjalankan inferensi JAX
Skenario: Sebelum menjalankan inferensi dengan model JAX, periksa penggunaan HBM (High Bandwidth Memory) saat ini di TPU untuk mengonfirmasi bahwa Anda memiliki memori yang cukup dan untuk mendapatkan pengukuran dasar sebelum inferensi dimulai.
# The following code sample shows how to check HBM utilization before JAX inference:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
# --- Your JAX model and inference setup would go here ---
# --- Example placeholder model (replace with your actual model loading/setup)---
def simple_model(x):
return jnp.sum(x)
key = jax.random.PRNGKey(0)
params = ... # Load your trained parameters
# Integrate the TPU Monitoring Library to get HBM utilization before inference
hbm_util_metric = tpumonitoring.get_metric("hbm_util")
hbm_util_data = hbm_util_metric.data
print("HBM Utilization Before Inference:")
print(f" Description: {hbm_util_metric.description}")
print(f" Data: {hbm_util_data}")
# End TPU Monitoring Library Integration
# Your Inference Logic
input_data = jnp.ones((1, 10)) # Example input
predictions = simple_model(input_data)
print("Inference Predictions:", predictions)
print("Inference complete.")
Memeriksa metrik jaringan
Skenario: Anda menjalankan workload multi-host dan multislice dan ingin terhubung ke salah satu pod atau TPU GKE menggunakan SSH untuk melihat metrik jaringan saat workload sedang berjalan. Perintah juga dapat digabungkan langsung ke beban kerja multi-host.
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
# --- Your JAX model and training setup goes here ---
# --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
return jnp.sum(x)
# --- details here ---
# ==============================================================================
# Metric 1: TCP Delivery Rate
# ==============================================================================
# This metric reflects the delivery rate of the TCP connection (bytes delivered / elapsed time).
# The output is a list of strings representing latency statistics:
# [mean, p50, p90, p95, p99.9]
# Example: ['100.00', '200.00', '300.00', '400.00', '500.00'] (Values in Mbps)
# Get the metric object
delivery_rate_metric = tpumonitoring.get_metric(metric_name_rate)
# Print the description provided by the library
print("Description:", delivery_rate_metric.description())
# Print the actual data payload
print("Data:", delivery_rate_metric.data())
# ==============================================================================
# Metric 2: TCP Minimum Round Trip Time (RTT)
# ==============================================================================
# This metric reflects the minimum RTT measured between sending a TCP packet
# and receiving the acknowledgement.
# The output is a list of strings representing latency statistics:
# [mean, p50, p90, p95, p99.9]
# Example: ['100.00', '200.00', '300.00', '400.00', '500.00'] (Values in us - microseconds)
# Get the metric object
min_rtt_metric = tpumonitoring.get_metric(metric_name_rtt)
# Print the description provided by the library
print("Description:", min_rtt_metric.description())
# Print the actual data payload
print("Data:", min_rtt_metric.data())
Frekuensi pembaruan metrik TPU
Frekuensi pemuatan ulang metrik TPU dibatasi minimal satu detik. Data metrik host diekspor pada frekuensi tetap 1 Hz. Latensi yang disebabkan oleh proses ekspor ini dapat diabaikan. Metrik runtime dari LibTPU tidak tunduk pada batasan frekuensi yang sama. Namun, agar konsisten, metrik ini juga diambil sampelnya pada 1 Hz atau 1 sampel per detik.
Modul TPU-Z
TPU-Z adalah fasilitas telemetri dan proses debug untuk TPU. Alat ini memberikan informasi status runtime mendetail untuk semua core TPU yang terhubung ke host. Fungsi ini disediakan melalui modul tpuz, yang merupakan bagian dari modul libtpu.sdk di libtpu Python SDK. Modul ini memberikan snapshot
status setiap core.
Kasus penggunaan utama TPU-Z adalah mendiagnosis hang atau kebuntuan dalam workload TPU terdistribusi. Anda dapat membuat kueri layanan TPU-Z di host untuk merekam status setiap core, membandingkan Penghitung Program, lokasi HLO, dan ID Run di semua core untuk mengidentifikasi anomali.
Gunakan fungsi get_core_state_summary() dalam library libtpu.sdk untuk menampilkan metrik TPU-Z:
summary = sdk.tpuz.get_core_state_summary()
Output untuk metrik TPU-Z diberikan sebagai kamus. Berikut adalah contoh yang dipangkas untuk satu core:
{
"host_name": "my-tpu-host-vm",
"core_states": {
"1": {
"core_id": {
"global_core_id": 1,
"chip_id": 0,
"core_on_chip": {
"type": "TPU_CORE_TYPE_TENSOR_CORE",
"index": 1
}
},
"sequencer_info": [
{
"sequencer_type": "TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER",
"sequencer_index": 0,
"pc": 4490,
"program_id": 3274167277388825310,
"run_id": 3
}
],
"program_fingerprint": "b'\\xbefB\\xc6\\x1eb\\xc1#\\xd0...'",
"queued_program_info": [],
"error_message": ""
}
// ...
}
}
Untuk mengambil informasi tentang Pengoptimal Tingkat Tinggi (HLO) di setiap core, tetapkan parameter include_hlo_info ke True:
summary = sdk.tpuz.get_core_state_summary(include_hlo_info=True)
Output mencakup Informasi HLO tambahan:
"1": {
"core_id": {
"global_core_id": 1,
"chip_id": 0,
"core_on_chip": {
"type": "TPU_CORE_TYPE_TENSOR_CORE",
"index": 1
}
},
"sequencer_info": [
{
"sequencer_type": "TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER",
"sequencer_index": 0,
"pc": 17776,
"tag": 3,
"tracemark": 2147483646,
"program_id": 3230481660274331500,
"run_id": 81,
"hlo_location": "HLO: fusion.11; HLO computation: main.126_spmd",
"hlo_detailed_info": "[{\"details\":\"HloModule fusion.11, entry_computation_layout={(bf16[20>..."
}
],
"program_fingerprint": "\ufffdU\ufffd4j\u7c6e\ufffd\ufffd{\u0017\ufffd\ufffdHHV\ufffdD\ufffde\uff>",
"launch_id": 1394130914,
"queued_program_info": [
{
"run_id": 81,
"launch_id": 1394130914,
"program_fingerprint": "\ufffdU\ufffd4j\u7c6e\ufffd\ufffd{\u0017\ufffd\ufffdHHV\ufffdD\ufffde>"
}
]
}
Metrik TPU-Z
Fungsi get_core_state_summary menampilkan metrik TPU-Z dalam bentuk
kamus dengan struktur berikut.
CurrentCoreStateSummary
Kamus CurrentCoreStateSummary memberikan ringkasan mendetail tentang status setiap core TPU.
| Kolom | Jenis | Deskripsi |
|---|---|---|
core_id |
kamus | Kamus TpuCoreIdentifier yang berisi informasi ID tentang core TPU. |
sequencer_info |
daftar kamus | Daftar kamus SequencerInfo, yang menjelaskan status setiap pengurut di core. |
program_fingerprint |
byte | Sidik jari program yang dijalankan di core ini. |
launch_id |
bilangan bulat | ID peluncuran program saat ini atau terbaru. |
queued_program_info |
daftar kamus | Daftar kamus QueuedProgramInfo untuk program yang diantrekan untuk dieksekusi. |
error_message |
string | Pesan error apa pun untuk inti ini. |
TpuCoreIdentifier
Kamus TpuCoreIdentifier memberikan informasi ID untuk core dalam sistem TPU.
| Kolom | Jenis | Deskripsi |
|---|---|---|
global_core_id |
bilangan bulat | ID inti. |
chip_id |
bilangan bulat | ID chip tempat inti berada. |
core_on_chip |
kamus | Kamus TpuCoreOnChip yang menjelaskan jenis inti dan indeksnya pada chip. |
TpuCoreOnChip
Kamus TpuCoreOnChip berisi informasi tentang properti inti
dalam chip tertentu.
| Kolom | Jenis | Deskripsi |
|---|---|---|
type |
string | Jenis core TPU. Misalnya: TPU_CORE_TYPE_TENSOR_CORE. |
index |
bilangan bulat | Indeks inti pada chip. |
SequencerInfo
Kamus SequencerInfo berisi informasi tentang status satu
pengurut pada core.
| Kolom | Jenis | Deskripsi |
|---|---|---|
sequencer_type |
string | Jenis pengurut. Misalnya: TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER. |
sequencer_index |
bilangan bulat | Indeks sequencer (jika ada beberapa jenis yang sama). |
pc |
bilangan bulat | Nilai Program Counter saat ini. |
program_id |
bilangan bulat | ID yang terkait dengan instance program tertentu yang diluncurkan untuk dieksekusi pada core TPU. |
run_id |
bilangan bulat | ID Run yang terkait dengan instance spesifik eksekusi program pada core TPU. |
hlo_location |
string | Informasi lokasi Pengoptimal Tingkat Tinggi. |
hlo_detailed_info |
string | Informasi Pengoptimal Tingkat Tinggi yang mendetail. |
QueuedProgramInfo
Kamus QueuedProgramInfo berisi informasi tentang program yang diantrekan
untuk dieksekusi di core.
| Kolom | Jenis | Deskripsi |
|---|---|---|
run_id |
bilangan bulat | ID Run untuk program yang diantrekan. |
launch_id |
bilangan bulat | ID Peluncuran untuk program yang diantrekan. |
program_fingerprint |
byte | Sidik jari program yang diantrekan. |
TPU-Z dengan JAX
Anda dapat mengakses metrik TPU-Z dalam beban kerja JAX melalui library libtpu.sdk.
Skrip Python berikut menggunakan JAX untuk komputasi tensor berperforma tinggi,
sekaligus menggunakan libtpu SDK di thread latar belakang untuk memantau
status dan aktivitas hardware TPU yang mendasarinya.
Sertakan paket Python berikut:
import jax
import jax.numpy as jnp
import time
import threading
from functools import partial
from libtpu import sdk
Fungsi monitor_tpu_status menggunakan thread latar belakang untuk terus menampilkan status operasional core TPU saat aplikasi utama menjalankan beban kerja JAX. Alat ini berfungsi sebagai alat diagnostik real-time.
def monitor_tpu_status():
"""Monitors TPU status in a background thread."""
while monitoring_active:
try:
summary = sdk.tpuz.get_core_state_summary(include_hlo_info=True)
if summary and 'core_states' in summary:
print(summary)
else:
print('WARNING: Call returned an empty or invalid summary.')
except RuntimeError as e:
print(f'FAIL: Error calling API: {e}')
except Exception as e:
print(f'FAIL: Unexpected error in monitor thread: {e}')
for _ in range(MONITORING_INTERVAL_SECONDS * 2):
if not monitoring_active:
break
time.sleep(0.5)
print('✅ Monitoring thread stopped.')
Fungsi transformer_block mengimplementasikan lapisan lengkap arsitektur Transformer, yang merupakan elemen dasar penyusun LLM.
@partial(jax.jit, static_argnames=['num_heads'])
def transformer_block(params, x, num_heads=32):
"""A simplified but computationally intensive Transformer block."""
# Multi-head Self-Attention
qkv = jnp.dot(x, params['qkv_kernel'])
q, k, v = jnp.array_split(qkv, 3, axis=-1)
# Reshape for multi-head attention
q = q.reshape(q.shape[0], q.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
k = k.reshape(k.shape[0], k.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
v = v.reshape(v.shape[0], v.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
# Scaled dot-product attention
attention_scores = jnp.einsum('nhqd,nhkd->nhqk', q, k) / jnp.sqrt(q.shape[-1])
attention_weights = jax.nn.softmax(attention_scores, axis=-1)
attention_output = jnp.einsum('nhqk,nhvd->nhqd', attention_weights, v)
attention_output = attention_output.transpose(0, 2, 1, 3).reshape(x.shape)
attention_output = jnp.dot(attention_output, params['o_kernel'])
# Residual connection and Layer Normalization 1
h1 = x + attention_output
h1_norm = h1 - jnp.mean(h1, axis=-1, keepdims=True)
h1_norm = h1_norm / jnp.sqrt(
jnp.mean(jnp.square(h1_norm), axis=-1, keepdims=True) + 1e-5
)
# Feed-Forward Network
ffn_hidden = jax.nn.gelu(jnp.dot(h1_norm, params['ffn1_kernel']))
ffn_output = jnp.dot(ffn_hidden, params['ffn2_kernel'])
# Residual connection and Layer Normalization 2
h2 = h1_norm + ffn_output
h2_norm = h2 - jnp.mean(h2, axis=-1, keepdims=True)
h2_norm = h2_norm / jnp.sqrt(
jnp.mean(jnp.square(h2_norm), axis=-1, keepdims=True) + 1e-5
)
return h2_norm
Fungsi main mengorkestrasi penyiapan komputasi JAX, memulai
pemantauan TPU di latar belakang, dan menjalankan loop workload utama.
def main():
num_devices = jax.device_count()
print(f"Running on {num_devices} devices.")
batch_size = 128 * num_devices
seq_len = 512
embed_dim = 1024
ffn_dim = embed_dim * 4
key = jax.random.PRNGKey(0)
params = {
'qkv_kernel': jax.random.normal(
key, (embed_dim, embed_dim * 3), dtype=jnp.bfloat16
),
'o_kernel': jax.random.normal(
key, (embed_dim, embed_dim), dtype=jnp.bfloat16
),
'ffn1_kernel': jax.random.normal(
key, (embed_dim, ffn_dim), dtype=jnp.bfloat16
),
'ffn2_kernel': jax.random.normal(
key, (ffn_dim, embed_dim), dtype=jnp.bfloat16
),
}
input_data = jax.random.normal(
key, (batch_size, seq_len, embed_dim), dtype=jnp.bfloat16
)
input_data = jax.device_put(input_data)
monitor_thread = threading.Thread(target=monitor_tpu_status)
monitor_thread.start()
print("Starting JAX computation loop...")
start_time = time.time()
iterations = 0
while time.time() - start_time < JOB_DURATION_SECONDS:
result = transformer_block(params, input_data)
result.block_until_ready()
iterations += 1
print(f' -> Jax iteration {iterations} complete.', end='\r')
print(f"\nCompleted {iterations} iterations in {JOB_DURATION_SECONDS} seconds.")
global monitoring_active
monitoring_active = False
monitor_thread.join()
if __name__ == '__main__':
main()
Pemecahan masalah
Bagian ini memberikan informasi pemecahan masalah untuk membantu Anda mengidentifikasi dan menyelesaikan masalah yang mungkin Anda alami saat menggunakan TPU Monitoring Library.
Fitur atau metrik tidak ada
Jika Anda tidak dapat melihat beberapa fitur atau metrik, penyebab paling umum adalah versi libtpu yang sudah tidak berlaku. Fitur dan metrik TPU Monitoring Library disertakan dalam rilis libtpu, dan versi yang sudah tidak berlaku mungkin tidak memiliki fitur dan metrik baru.
Periksa versi libtpu yang berjalan di lingkungan Anda:
Command line:
pip show libtpu
Python:
import libtpu
print(libtpu.__version__)
Jika Anda tidak menggunakan versi terbaru
libtpu, gunakan perintah berikut untuk mengupdate library:
pip install --upgrade libtpu