TPU Monitoring Library

Dapatkan insight mendalam tentang performa dan perilaku hardware Cloud TPU Anda dengan kemampuan pemantauan TPU tingkat lanjut, yang dibangun langsung di atas lapisan software dasar, LibTPU. Meskipun LibTPU mencakup driver, library jaringan, compiler XLA, dan runtime TPU untuk berinteraksi dengan TPU, dokumen ini berfokus pada TPU Monitoring Library.

TPU Monitoring Library menyediakan:

  • Observabilitas komprehensif: Dapatkan akses ke rangkaian metrik dan API telemetri, yang memberikan insight mendetail tentang performa operasional dan perilaku spesifik TPU Anda.

  • Toolkit diagnostik: Menyediakan SDK dan antarmuka command line (CLI) yang dirancang untuk memungkinkan proses debug dan analisis performa mendalam pada resource TPU Anda.

Fitur pemantauan ini dirancang sebagai solusi tingkat teratas yang ditujukan untuk pelanggan, sehingga memberi Anda alat penting untuk mengoptimalkan beban kerja TPU secara efektif.

TPU Monitoring Library memberi Anda informasi mendetail tentang performa workload machine learning pada hardware TPU. Alat ini dirancang untuk membantu Anda memahami penggunaan TPU, mengidentifikasi hambatan, dan men-debug masalah performa. Metrik ini memberi Anda informasi yang lebih mendetail daripada metrik gangguan, metrik goodput, dan metrik lainnya.

Mulai menggunakan TPU Monitoring Library

Mengakses insight yang efektif ini sangatlah mudah. Fungsi pemantauan TPU diintegrasikan dengan LibTPU SDK, sehingga fungsi tersebut disertakan saat Anda menginstal LibTPU.

Instal LibTPU

pip install libtpu

Atau, update LibTPU dikoordinasikan dengan rilis JAX, yang berarti bahwa saat Anda menginstal rilis JAX terbaru (dirilis setiap bulan), Anda biasanya akan menggunakan versi LibTPU yang kompatibel dan fiturnya yang terbaru.

Menginstal JAX

pip install -U "jax[tpu]"

Untuk pengguna PyTorch, menginstal PyTorch/XLA akan memberikan fungsi pemantauan TPU dan LibTPU terbaru.

Menginstal PyTorch/XLA

pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' \
  -f https://storage.googleapis.com/libtpu-releases/index.html \
  -f https://storage.googleapis.com/libtpu-wheels/index.html

  # Optional: if you're using custom kernels, install pallas dependencies
pip install 'torch_xla[pallas]' \
  -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
  -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

Untuk mengetahui informasi selengkapnya tentang cara menginstal PyTorch/XLA, lihat Penginstalan di repositori GitHub PyTorch/XLA.

Mengimpor library di Python

Untuk mulai menggunakan TPU Monitoring Library, Anda perlu mengimpor modul libtpu dalam kode Python Anda.

from libtpu.sdk import tpumonitoring

Mencantumkan semua fungsi yang didukung

Mencantumkan semua nama metrik dan fungsi yang didukungnya:


from libtpu.sdk import tpumonitoring

tpumonitoring.help()
" libtpu.sdk.monitoring.help():
      List all supported functionality.

  libtpu.sdk.monitoring.list_support_metrics()
      List support metric names in the list of str format.

  libtpu.sdk.monitoring.get_metric(metric_name:str)
      Get metric data with metric name. It represents the snapshot mode.
      The metric data is a object with `description()` and `data()` methods,
      where the `description()` returns a string describe the format of data
      and data unit, `data()` returns the metric data in the list in str format.
"

Metrik yang didukung

Contoh kode berikut menunjukkan cara mencantumkan semua nama metrik yang didukung:

from libtpu.sdk import tpumonitoring

tpumonitoring.list_supported_metrics()

["duty_cycle_pct", "tensorcore_util", "hbm_util", ...]

Tabel berikut menampilkan semua metrik dan definisi yang sesuai:

Metrik Definisi Nama metrik untuk API Contoh nilai
Penggunaan Tensor Core Mengukur persentase penggunaan TensorCore Anda, yang dihitung sebagai persentase operasi yang merupakan bagian dari operasi TensorCore. Sampel diambil 10 mikrodetik setiap 1 detik. Anda tidak dapat mengubah frekuensi pengambilan sampel. Metrik ini memungkinkan Anda memantau efisiensi beban kerja di perangkat TPU. tensorcore_util ['1.11', '2.22', '3.33', '4.44']

# persentase pemanfaatan untuk ID akselerator 0-3.
Persentase Siklus Tugas Persentase waktu selama periode sampel terakhir (setiap 5 detik; dapat disesuaikan dengan menetapkan flag LIBTPU_INIT_ARG) saat akselerator secara aktif memproses (direkam dengan siklus yang digunakan untuk mengeksekusi program HLO selama periode pengambilan sampel terakhir). Metrik ini menunjukkan seberapa sibuk TPU. Metrik ini dikeluarkan per chip. duty_cycle_pct ['10.00', '20.00', '30.00', '40.00']

# Persentase siklus tugas untuk ID akselerator 0-3.
Total Kapasitas HBM Metrik ini melaporkan total kapasitas HBM dalam byte. hbm_capacity_total ['30000000000', '30000000000', '30000000000', '30000000000']

# Total kapasitas HBM dalam byte yang terpasang ke ID akselerator 0-3.
Penggunaan Kapasitas HBM Metrik ini melaporkan penggunaan kapasitas HBM dalam byte selama periode pengambilan sampel terakhir (setiap 5 detik; dapat disesuaikan dengan menyetel flag LIBTPU_INIT_ARG). hbm_capacity_usage ['100', '200', '300', '400']

# Penggunaan kapasitas untuk HBM dalam byte yang terhubung ke ID akselerator 0-3.
Latensi transfer buffer Latensi transfer jaringan untuk traffic multislice berskala besar. Visualisasi ini memungkinkan Anda memahami lingkungan performa jaringan secara keseluruhan. buffer_transfer_latency ["'8MB+', '2233.25', '2182.02', '3761.93', '19277.01', '53553.6'"]

# ukuran buffer, rata-rata, p50, p90, p99, p99,9 dari distribusi latensi transfer jaringan.
Metrik Distribusi Waktu Eksekusi Operasi Tingkat Tinggi Memberikan insight performa terperinci tentang status eksekusi biner yang dikompilasi HLO, sehingga memungkinkan deteksi regresi dan proses debug tingkat model. hlo_exec_timing ["'tensorcore-0', '10.00', '10.00', '20.00', '30.00', '40.00'"]

# Distribusi durasi waktu eksekusi HLO untuk CoreType-CoreID dengan mean, p50, p90, p95, p999.
Ukuran antrean Pengoptimal Tingkat Tinggi Pemantauan ukuran antrean eksekusi HLO melacak jumlah program HLO yang dikompilasi yang menunggu atau sedang dieksekusi. Metrik ini mengungkapkan kemacetan pipeline eksekusi, sehingga memungkinkan identifikasi bottleneck performa dalam eksekusi hardware, overhead driver, atau alokasi resource. hlo_queue_size ["tensorcore-0: 1", "tensorcore-1: 2"]

# Mengukur ukuran antrean untuk CoreType-CoreID.
Latensi End-to-End Kolektif Metrik ini mengukur latensi kolektif end-to-end melalui DCN dalam mikrodetik, dari host yang memulai operasi hingga semua peer yang menerima output. Hal ini mencakup pengurangan data sisi host dan pengiriman output ke TPU. Hasilnya adalah string yang menjelaskan ukuran buffer, jenis, dan latensi rata-rata, p50, p90, p95, dan p99,9. collective_e2e_latency ["8MB+-ALL_REDUCE, 1000, 2000, 3000, 4000, 5000", …]

# Operasi kolektif ukuran transfer, rata-rata, p50, p90, p95, p999 latensi end-to-end kolektif.
Latensi Round Trip di Lapisan Transport Distribusi Round Trip Time (RTT) minimum yang diamati pada koneksi TCP yang digunakan oleh gRPC untuk traffic TPU multislice. grpc_tcp_min_round_trip_times ['27.63, 29.03, 38.52, 41.63, 52.74']

# Mewakili persentil rata-rata, p50, p90, p95, dan p99,9 distribusi dalam mikrodetik (µs).
Throughput di Lapisan Transport Distribusi kumulatif throughput terbaru koneksi TCP yang digunakan oleh gRPC untuk traffic TPU multislice. grpc_tcp_delivery_rates ['11354.89, 10986.35, 18239.55, 25718.35, 44841.55']

# Mewakili persentil rata-rata, p50, p90, p95, dan p99,9 distribusi dalam mikrodetik (µs).

Membaca data metrik

Untuk membaca data metrik, tentukan nama metrik saat Anda memanggil fungsi tpumonitoring.get_metric. Anda dapat menyisipkan pemeriksaan metrik ad hoc ke dalam kode berperforma rendah untuk mengidentifikasi apakah masalah performa berasal dari software atau hardware.

Contoh kode berikut menunjukkan cara membaca metrik duty_cycle:

from libtpu.sdk import tpumonitoring

metric = tpumonitoring.get_metric("duty_cycle_pct")

metric.description()
"The metric provides a list of duty cycle percentages, one for each
accelerator (from accelerator_0 to accelerator_x). The duty cycle represents
the percentage of time an accelerator was actively processing during the
last sample period, indicating TPU utilization."

metric.data()
["0.00", "0.00", "0.00", "0.00"]

# accelerator_0-3

Menggunakan metrik untuk memeriksa pemanfaatan TPU

Contoh berikut menunjukkan cara menggunakan metrik dari TPU Monitoring Library untuk melacak pemanfaatan TPU.

Memantau siklus tugas TPU selama pelatihan JAX

Skenario: Anda menjalankan skrip pelatihan JAX dan ingin memantau metrik duty_cycle_pct TPU selama proses pelatihan untuk mengonfirmasi bahwa TPU Anda digunakan secara efektif. Anda dapat mencatat metrik ini secara berkala selama pelatihan untuk melacak pemakaian TPU.

Contoh kode berikut menunjukkan cara memantau Siklus Tugas TPU selama pelatihan JAX:

import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
import time

 # --- Your JAX model and training setup would go here ---
 #  --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
    return jnp.sum(x)

def loss_fn(params, x, y):
    preds = simple_model(x)
    return jnp.mean((preds - y)**2)

def train_step(params, x, y, optimizer):
    grads = jax.grad(loss_fn)(params, x, y)
    return optimizer.update(grads, params)

key = jax.random.PRNGKey(0)
params = jnp.array([1.0, 2.0]) # Example params
optimizer = ... # Your optimizer (for example, optax.adam)
data_x = jnp.ones((10, 10))
data_y = jnp.zeros((10,))

num_epochs = 10
log_interval_steps = 2  # Log duty cycle every 2 steps

for epoch in range(num_epochs):
    for step in range(5): # Example steps per epoch

        params = train_step(params, data_x, data_y, optimizer)

        if (step + 1) % log_interval_steps == 0:
            # --- Integrate TPU Monitoring Library here to get duty_cycle ---
            duty_cycle_metric = tpumonitoring.get_metric("duty_cycle_pct")
            duty_cycle_data = duty_cycle_metric.data
            print(f"Epoch {epoch+1}, Step {step+1}: TPU Duty Cycle Data:")
            print(f"  Description: {duty_cycle_metric.description}")
            print(f"  Data: {duty_cycle_data}")
            # --- End TPU Monitoring Library Integration ---

        # --- Rest of your training loop logic ---
        time.sleep(0.1) # Simulate some computation

print("Training complete.")

Memeriksa pemanfaatan HBM sebelum menjalankan inferensi JAX

Skenario: Sebelum menjalankan inferensi dengan model JAX, periksa penggunaan HBM (High Bandwidth Memory) saat ini di TPU untuk mengonfirmasi bahwa Anda memiliki memori yang cukup dan untuk mendapatkan pengukuran dasar sebelum inferensi dimulai.

# The following code sample shows how to check HBM utilization before JAX inference:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring

  # --- Your JAX model and inference setup would go here ---
  # --- Example placeholder model (replace with your actual model loading/setup)---
def simple_model(x):
    return jnp.sum(x)

key = jax.random.PRNGKey(0)
params = ... # Load your trained parameters

  # Integrate the TPU Monitoring Library to get HBM utilization before inference
hbm_util_metric = tpumonitoring.get_metric("hbm_util")
hbm_util_data = hbm_util_metric.data
print("HBM Utilization Before Inference:")
print(f"  Description: {hbm_util_metric.description}")
print(f"  Data: {hbm_util_data}")
  # End TPU Monitoring Library Integration

  # Your Inference Logic
input_data = jnp.ones((1, 10)) # Example input
predictions = simple_model(input_data)
print("Inference Predictions:", predictions)

print("Inference complete.")

Memeriksa metrik jaringan

Skenario: Anda menjalankan workload multi-host dan multislice dan ingin terhubung ke salah satu pod atau TPU GKE menggunakan SSH untuk melihat metrik jaringan saat workload sedang berjalan. Perintah juga dapat digabungkan langsung ke beban kerja multi-host.

import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring

 # --- Your JAX model and training setup goes here ---
 #  --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
    return jnp.sum(x)
     # --- details here ---

    # ==============================================================================
    # Metric 1: TCP Delivery Rate
    # ==============================================================================
    # This metric reflects the delivery rate of the TCP connection (bytes delivered / elapsed time).
    # The output is a list of strings representing latency statistics:
    # [mean, p50, p90, p95, p99.9]
    # Example: ['100.00', '200.00', '300.00', '400.00', '500.00'] (Values in Mbps)

        # Get the metric object
        delivery_rate_metric = tpumonitoring.get_metric(metric_name_rate)

        # Print the description provided by the library
        print("Description:", delivery_rate_metric.description())

        # Print the actual data payload
        print("Data:", delivery_rate_metric.data())

    # ==============================================================================
    # Metric 2: TCP Minimum Round Trip Time (RTT)
    # ==============================================================================
    # This metric reflects the minimum RTT measured between sending a TCP packet
    # and receiving the acknowledgement.
    # The output is a list of strings representing latency statistics:
    # [mean, p50, p90, p95, p99.9]
    # Example: ['100.00', '200.00', '300.00', '400.00', '500.00'] (Values in us - microseconds)

        # Get the metric object
        min_rtt_metric = tpumonitoring.get_metric(metric_name_rtt)

        # Print the description provided by the library
        print("Description:", min_rtt_metric.description())

        # Print the actual data payload
        print("Data:", min_rtt_metric.data())

Frekuensi pembaruan metrik TPU

Frekuensi pemuatan ulang metrik TPU dibatasi minimal satu detik. Data metrik host diekspor pada frekuensi tetap 1 Hz. Latensi yang disebabkan oleh proses ekspor ini dapat diabaikan. Metrik runtime dari LibTPU tidak tunduk pada batasan frekuensi yang sama. Namun, agar konsisten, metrik ini juga diambil sampelnya pada 1 Hz atau 1 sampel per detik.

Modul TPU-Z

TPU-Z adalah fasilitas telemetri dan proses debug untuk TPU. Alat ini memberikan informasi status runtime mendetail untuk semua core TPU yang terhubung ke host. Fungsi ini disediakan melalui modul tpuz, yang merupakan bagian dari modul libtpu.sdk di libtpu Python SDK. Modul ini memberikan snapshot status setiap core.

Kasus penggunaan utama TPU-Z adalah mendiagnosis hang atau kebuntuan dalam workload TPU terdistribusi. Anda dapat membuat kueri layanan TPU-Z di host untuk merekam status setiap core, membandingkan Penghitung Program, lokasi HLO, dan ID Run di semua core untuk mengidentifikasi anomali.

Gunakan fungsi get_core_state_summary() dalam library libtpu.sdk untuk menampilkan metrik TPU-Z:

summary = sdk.tpuz.get_core_state_summary()

Output untuk metrik TPU-Z diberikan sebagai kamus. Berikut adalah contoh yang dipangkas untuk satu core:

{
  "host_name": "my-tpu-host-vm",
  "core_states": {
    "1": {
      "core_id": {
        "global_core_id": 1,
        "chip_id": 0,
        "core_on_chip": {
          "type": "TPU_CORE_TYPE_TENSOR_CORE",
          "index": 1
        }
      },
      "sequencer_info": [
        {
          "sequencer_type": "TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER",
          "sequencer_index": 0,
          "pc": 4490,
          "program_id": 3274167277388825310,
          "run_id": 3
        }
      ],
      "program_fingerprint": "b'\\xbefB\\xc6\\x1eb\\xc1#\\xd0...'",
      "queued_program_info": [],
      "error_message": ""
    }
    // ...
  }
}

Untuk mengambil informasi tentang Pengoptimal Tingkat Tinggi (HLO) di setiap core, tetapkan parameter include_hlo_info ke True:

summary = sdk.tpuz.get_core_state_summary(include_hlo_info=True)

Output mencakup Informasi HLO tambahan:

"1": {
    "core_id": {
      "global_core_id": 1,
      "chip_id": 0,
      "core_on_chip": {
        "type": "TPU_CORE_TYPE_TENSOR_CORE",
        "index": 1
      }
    },
    "sequencer_info": [
      {
        "sequencer_type": "TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER",
        "sequencer_index": 0,
        "pc": 17776,
        "tag": 3,
        "tracemark": 2147483646,
        "program_id": 3230481660274331500,
        "run_id": 81,
        "hlo_location": "HLO: fusion.11; HLO computation: main.126_spmd",
        "hlo_detailed_info": "[{\"details\":\"HloModule fusion.11, entry_computation_layout={(bf16[20>..."
      }
    ],
    "program_fingerprint": "\ufffdU\ufffd4j\u7c6e\ufffd\ufffd{\u0017\ufffd\ufffdHHV\ufffdD\ufffde\uff>",
    "launch_id": 1394130914,
    "queued_program_info": [
      {
        "run_id": 81,
        "launch_id": 1394130914,
        "program_fingerprint": "\ufffdU\ufffd4j\u7c6e\ufffd\ufffd{\u0017\ufffd\ufffdHHV\ufffdD\ufffde>"
      }
    ]
  }

Metrik TPU-Z

Fungsi get_core_state_summary menampilkan metrik TPU-Z dalam bentuk kamus dengan struktur berikut.

CurrentCoreStateSummary

Kamus CurrentCoreStateSummary memberikan ringkasan mendetail tentang status setiap core TPU.

Kolom Jenis Deskripsi
core_id kamus Kamus TpuCoreIdentifier yang berisi informasi ID tentang core TPU.
sequencer_info daftar kamus Daftar kamus SequencerInfo, yang menjelaskan status setiap pengurut di core.
program_fingerprint byte Sidik jari program yang dijalankan di core ini.
launch_id bilangan bulat ID peluncuran program saat ini atau terbaru.
queued_program_info daftar kamus Daftar kamus QueuedProgramInfo untuk program yang diantrekan untuk dieksekusi.
error_message string Pesan error apa pun untuk inti ini.

TpuCoreIdentifier

Kamus TpuCoreIdentifier memberikan informasi ID untuk core dalam sistem TPU.

Kolom Jenis Deskripsi
global_core_id bilangan bulat ID inti.
chip_id bilangan bulat ID chip tempat inti berada.
core_on_chip kamus Kamus TpuCoreOnChip yang menjelaskan jenis inti dan indeksnya pada chip.

TpuCoreOnChip

Kamus TpuCoreOnChip berisi informasi tentang properti inti dalam chip tertentu.

Kolom Jenis Deskripsi
type string Jenis core TPU. Misalnya: TPU_CORE_TYPE_TENSOR_CORE.
index bilangan bulat Indeks inti pada chip.

SequencerInfo

Kamus SequencerInfo berisi informasi tentang status satu pengurut pada core.

Kolom Jenis Deskripsi
sequencer_type string Jenis pengurut. Misalnya: TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER.
sequencer_index bilangan bulat Indeks sequencer (jika ada beberapa jenis yang sama).
pc bilangan bulat Nilai Program Counter saat ini.
program_id bilangan bulat ID yang terkait dengan instance program tertentu yang diluncurkan untuk dieksekusi pada core TPU.
run_id bilangan bulat ID Run yang terkait dengan instance spesifik eksekusi program pada core TPU.
hlo_location string Informasi lokasi Pengoptimal Tingkat Tinggi.
hlo_detailed_info string Informasi Pengoptimal Tingkat Tinggi yang mendetail.

QueuedProgramInfo

Kamus QueuedProgramInfo berisi informasi tentang program yang diantrekan untuk dieksekusi di core.

Kolom Jenis Deskripsi
run_id bilangan bulat ID Run untuk program yang diantrekan.
launch_id bilangan bulat ID Peluncuran untuk program yang diantrekan.
program_fingerprint byte Sidik jari program yang diantrekan.

TPU-Z dengan JAX

Anda dapat mengakses metrik TPU-Z dalam beban kerja JAX melalui library libtpu.sdk. Skrip Python berikut menggunakan JAX untuk komputasi tensor berperforma tinggi, sekaligus menggunakan libtpu SDK di thread latar belakang untuk memantau status dan aktivitas hardware TPU yang mendasarinya.

Sertakan paket Python berikut:

import jax
import jax.numpy as jnp
import time
import threading
from functools import partial
from libtpu import sdk

Fungsi monitor_tpu_status menggunakan thread latar belakang untuk terus menampilkan status operasional core TPU saat aplikasi utama menjalankan beban kerja JAX. Alat ini berfungsi sebagai alat diagnostik real-time.

def monitor_tpu_status():
  """Monitors TPU status in a background thread."""

  while monitoring_active:
    try:
      summary = sdk.tpuz.get_core_state_summary(include_hlo_info=True)
      if summary and 'core_states' in summary:
        print(summary)
      else:
        print('WARNING: Call returned an empty or invalid summary.')
    except RuntimeError as e:
      print(f'FAIL: Error calling API: {e}')
    except Exception as e:
      print(f'FAIL: Unexpected error in monitor thread: {e}')

    for _ in range(MONITORING_INTERVAL_SECONDS * 2):
      if not monitoring_active:
        break
      time.sleep(0.5)
  print('✅ Monitoring thread stopped.')

Fungsi transformer_block mengimplementasikan lapisan lengkap arsitektur Transformer, yang merupakan elemen dasar penyusun LLM.

@partial(jax.jit, static_argnames=['num_heads'])
def transformer_block(params, x, num_heads=32):
  """A simplified but computationally intensive Transformer block."""
  # Multi-head Self-Attention
  qkv = jnp.dot(x, params['qkv_kernel'])
  q, k, v = jnp.array_split(qkv, 3, axis=-1)

  # Reshape for multi-head attention
  q = q.reshape(q.shape[0], q.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
  k = k.reshape(k.shape[0], k.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
  v = v.reshape(v.shape[0], v.shape[1], num_heads, -1).transpose(0, 2, 1, 3)

  # Scaled dot-product attention
  attention_scores = jnp.einsum('nhqd,nhkd->nhqk', q, k) / jnp.sqrt(q.shape[-1])
  attention_weights = jax.nn.softmax(attention_scores, axis=-1)
  attention_output = jnp.einsum('nhqk,nhvd->nhqd', attention_weights, v)
  attention_output = attention_output.transpose(0, 2, 1, 3).reshape(x.shape)

  attention_output = jnp.dot(attention_output, params['o_kernel'])

  # Residual connection and Layer Normalization 1
  h1 = x + attention_output
  h1_norm = h1 - jnp.mean(h1, axis=-1, keepdims=True)
  h1_norm = h1_norm / jnp.sqrt(
      jnp.mean(jnp.square(h1_norm), axis=-1, keepdims=True) + 1e-5
  )

  # Feed-Forward Network
  ffn_hidden = jax.nn.gelu(jnp.dot(h1_norm, params['ffn1_kernel']))
  ffn_output = jnp.dot(ffn_hidden, params['ffn2_kernel'])

  # Residual connection and Layer Normalization 2
  h2 = h1_norm + ffn_output
  h2_norm = h2 - jnp.mean(h2, axis=-1, keepdims=True)
  h2_norm = h2_norm / jnp.sqrt(
      jnp.mean(jnp.square(h2_norm), axis=-1, keepdims=True) + 1e-5
  )

  return h2_norm

Fungsi main mengorkestrasi penyiapan komputasi JAX, memulai pemantauan TPU di latar belakang, dan menjalankan loop workload utama.

def main():
  num_devices = jax.device_count()
  print(f"Running on {num_devices} devices.")

  batch_size = 128 * num_devices
  seq_len = 512
  embed_dim = 1024
  ffn_dim = embed_dim * 4

  key = jax.random.PRNGKey(0)

  params = {
      'qkv_kernel': jax.random.normal(
          key, (embed_dim, embed_dim * 3), dtype=jnp.bfloat16
      ),
      'o_kernel': jax.random.normal(
          key, (embed_dim, embed_dim), dtype=jnp.bfloat16
      ),
      'ffn1_kernel': jax.random.normal(
          key, (embed_dim, ffn_dim), dtype=jnp.bfloat16
      ),
      'ffn2_kernel': jax.random.normal(
          key, (ffn_dim, embed_dim), dtype=jnp.bfloat16
      ),
  }
  input_data = jax.random.normal(
      key, (batch_size, seq_len, embed_dim), dtype=jnp.bfloat16
  )
  input_data = jax.device_put(input_data)
  monitor_thread = threading.Thread(target=monitor_tpu_status)
  monitor_thread.start()
  print("Starting JAX computation loop...")
  start_time = time.time()
  iterations = 0
  while time.time() - start_time < JOB_DURATION_SECONDS:
    result = transformer_block(params, input_data)
    result.block_until_ready()
    iterations += 1
    print(f'  -> Jax iteration {iterations} complete.', end='\r')

  print(f"\nCompleted {iterations} iterations in {JOB_DURATION_SECONDS} seconds.")

  global monitoring_active
  monitoring_active = False
  monitor_thread.join()

if __name__ == '__main__':
  main()

Pemecahan masalah

Bagian ini memberikan informasi pemecahan masalah untuk membantu Anda mengidentifikasi dan menyelesaikan masalah yang mungkin Anda alami saat menggunakan TPU Monitoring Library.

Fitur atau metrik tidak ada

Jika Anda tidak dapat melihat beberapa fitur atau metrik, penyebab paling umum adalah versi libtpu yang sudah tidak berlaku. Fitur dan metrik TPU Monitoring Library disertakan dalam rilis libtpu, dan versi yang sudah tidak berlaku mungkin tidak memiliki fitur dan metrik baru.

Periksa versi libtpu yang berjalan di lingkungan Anda:

Command line:

pip show libtpu

Python:

import libtpu

print(libtpu.__version__)

Jika Anda tidak menggunakan versi terbaru libtpu, gunakan perintah berikut untuk mengupdate library:

pip install --upgrade libtpu