Libreria di monitoraggio TPU
Ottieni informazioni approfondite sulle prestazioni e sul comportamento dell'hardware Cloud TPU con funzionalità di monitoraggio avanzate di TPU, basate direttamente sul livello software di base, LibTPU. Sebbene LibTPU comprenda driver, librerie di rete, il compilatore XLA e il runtime TPU per l'interazione con le TPU, questo documento si concentra sulla libreria di monitoraggio TPU.
La libreria di monitoraggio TPU fornisce:
Osservabilità completa: accedi all'API Telemetry e alla suite di metriche, che forniscono informazioni dettagliate sul rendimento operativo e sui comportamenti specifici delle TPU.
Toolkit di diagnostica: fornisce un SDK e un'interfaccia a riga di comando (CLI) progettati per consentire il debug e l'analisi approfondita delle prestazioni delle risorse TPU.
Queste funzionalità di monitoraggio sono progettate per essere una soluzione di primo livello rivolta ai clienti, fornendoti gli strumenti essenziali per ottimizzare in modo efficace i tuoi workload TPU.
La libreria di monitoraggio TPU fornisce informazioni dettagliate sul rendimento dei carichi di lavoro di machine learning sull'hardware TPU. È progettato per aiutarti a comprendere l'utilizzo delle TPU, identificare i colli di bottiglia ed eseguire il debug dei problemi di prestazioni. Fornisce informazioni più dettagliate rispetto alle metriche di interruzione, goodput e altre metriche.
Inizia a utilizzare la libreria di monitoraggio TPU
Accedere a questi potenti approfondimenti è semplice. La funzionalità di monitoraggio della TPU è integrata nell'SDK LibTPU, pertanto è inclusa quando installi LibTPU.
Installare LibTPU
pip install libtpu
In alternativa, gli aggiornamenti di LibTPU sono coordinati con le release di JAX, il che significa che quando installi l'ultima release di JAX (rilasciata mensilmente), in genere viene installata l'ultima versione compatibile di LibTPU e le relative funzionalità.
Installare JAX
pip install -U "jax[tpu]"
Per gli utenti di PyTorch, l'installazione di PyTorch/XLA fornisce le funzionalità di monitoraggio più recenti di LibTPU e TPU.
Installare PyTorch/XLA
pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' \
-f https://storage.googleapis.com/libtpu-releases/index.html \
-f https://storage.googleapis.com/libtpu-wheels/index.html
# Optional: if you're using custom kernels, install pallas dependencies
pip install 'torch_xla[pallas]' \
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
Per saperne di più sull'installazione di PyTorch/XLA, consulta la sezione Installazione nel repository GitHub di PyTorch/XLA.
Importa la libreria in Python
Per iniziare a utilizzare la libreria di monitoraggio TPU, devi importare il modulo libtpu nel codice Python.
from libtpu.sdk import tpumonitoring
Elenca tutte le funzionalità supportate
Elenca tutti i nomi delle metriche e le funzionalità che supportano:
from libtpu.sdk import tpumonitoring
tpumonitoring.help()
" libtpu.sdk.monitoring.help():
List all supported functionality.
libtpu.sdk.monitoring.list_support_metrics()
List support metric names in the list of str format.
libtpu.sdk.monitoring.get_metric(metric_name:str)
Get metric data with metric name. It represents the snapshot mode.
The metric data is a object with `description()` and `data()` methods,
where the `description()` returns a string describe the format of data
and data unit, `data()` returns the metric data in the list in str format.
"
Metriche supportate
Il seguente esempio di codice mostra come elencare tutti i nomi delle metriche supportate:
from libtpu.sdk import tpumonitoring
tpumonitoring.list_supported_metrics()
["duty_cycle_pct", "tensorcore_util", "hbm_util", ...]
La tabella seguente mostra tutte le metriche e le relative definizioni:
| Metrica | Definizione | Nome della metrica per l'API | Valori di esempio |
|---|---|---|---|
| Utilizzo di Tensor Core | Misura la percentuale di utilizzo di TensorCore, calcolata come percentuale di operazioni incluse nelle operazioni TensorCore. Campionamento di 10 microsecondi ogni secondo. Non puoi modificare la frequenza di campionamento. Questa metrica consente di monitorare l'efficienza dei workload sui dispositivi TPU. |
tensorcore_util
|
['1.11', '2.22', '3.33', '4.44']
# utilization percentage for accelerator ID 0-3. |
| Percentuale del ciclo di lavoro | Percentuale di tempo nell'ultimo periodo di campionamento (ogni 5 secondi; può
essere ottimizzata impostando il flag LIBTPU_INIT_ARG) durante il quale
l'acceleratore ha eseguito attivamente l'elaborazione (registrata con i cicli utilizzati per
eseguire i programmi HLO nell'ultimo periodo di campionamento). Questa metrica
rappresenta il livello di utilizzo di una TPU. La metrica viene emessa per chip.
|
duty_cycle_pct
|
['10.00', '20.00', '30.00', '40.00']
# Duty cycle percentage for accelerator ID 0-3. |
| HBM Capacity Total | Questa metrica indica la capacità totale HBM in byte. |
hbm_capacity_total
|
['30000000000', '30000000000', '30000000000', '30000000000']
# Capacità HBM totale in byte collegata all'ID acceleratore 0-3. |
| Utilizzo capacità HBM | Questa metrica indica l'utilizzo della capacità HBM in byte nell'ultimo
periodo di campionamento (ogni 5 secondi; può essere modificato impostando
il flag LIBTPU_INIT_ARG).
|
hbm_capacity_usage
|
['100', '200', '300', '400']
# Utilizzo della capacità per HBM in byte collegati all'ID acceleratore 0-3. |
| Latenza di trasferimento del buffer | Latenze di trasferimento di rete per il traffico multislice su larga scala. Questa visualizzazione ti consente di comprendere l'ambiente di rendimento complessivo della rete. |
buffer_transfer_latency
|
["'8MB+', '2233.25', '2182.02', '3761.93', '19277.01', '53553.6'"]
# Dimensione buffer, media, p50, p90, p99, p99.9 della distribuzione della latenza di trasferimento di rete. |
| Metriche di distribuzione del tempo di esecuzione delle operazioni di alto livello | Fornisce informazioni dettagliate sul rendimento dello stato di esecuzione del binario compilato HLO, consentendo il rilevamento della regressione e il debug a livello di modello. |
hlo_exec_timing
|
["'tensorcore-0', '10.00', '10.00', '20.00', '30.00', '40.00'"]
# The HLO execution time duration distribution for CoreType-CoreID with mean, p50, p90, p95, p999. |
| Dimensioni della coda dello strumento di ottimizzazione di alto livello | Il monitoraggio delle dimensioni della coda di esecuzione HLO tiene traccia del numero di programmi HLO compilati in attesa o in fase di esecuzione. Questa metrica rivela la congestione della pipeline di esecuzione, consentendo l'identificazione di colli di bottiglia delle prestazioni nell'esecuzione hardware, nell'overhead dei driver o nell'allocazione delle risorse. |
hlo_queue_size
|
["tensorcore-0: 1", "tensorcore-1: 2"]
# Measures queue size for CoreType-CoreID. |
| Latenza end-to-end collettiva | Questa metrica misura la latenza collettiva end-to-end su DCN in microsecondi, dall'host che avvia l'operazione a tutti i peer che ricevono l'output. Include la riduzione e l'invio dei dati lato host all'unità TPU. I risultati sono stringhe che descrivono in dettaglio le dimensioni, il tipo e le latenze medie, p50, p90, p95 e p99,9 del buffer. |
collective_e2e_latency
|
["8MB+-ALL_REDUCE, 1000, 2000, 3000, 4000, 5000", …]
# Transfer size-collective op, mean, p50, p90, p95, p999 of collective end to end latency. |
| Latenza di round trip a livello di trasporto | Distribuzione dei Round Trip Time (RTT) minimi osservati sulle connessioni TCP utilizzate da gRPC per il traffico TPU multislice. |
grpc_tcp_min_round_trip_times
|
['27.63, 29.03, 38.52, 41.63, 52.74']
# Represents the distribution's mean, p50, p90, p95, and p99.9 percentiles in microseconds (µs). |
| Velocità effettiva a livello di trasporto | Distribuzione cumulativa del throughput recente delle connessioni TCP utilizzate da gRPC per il traffico TPU multislice. |
grpc_tcp_delivery_rates
|
['11354.89, 10986.35, 18239.55, 25718.35, 44841.55']
# Represents the distribution's mean, p50, p90, p95, and p99.9 percentiles in microseconds (µs). |
Lettura dei dati delle metriche
Per leggere i dati delle metriche, specifica il nome della metrica quando chiami la funzione
tpumonitoring.get_metric. Puoi inserire controlli delle metriche ad hoc nel codice con prestazioni scarse per identificare se i problemi di prestazioni derivano da software o hardware.
Il seguente esempio di codice mostra come leggere la metrica duty_cycle:
from libtpu.sdk import tpumonitoring
metric = tpumonitoring.get_metric("duty_cycle_pct")
metric.description()
"The metric provides a list of duty cycle percentages, one for each
accelerator (from accelerator_0 to accelerator_x). The duty cycle represents
the percentage of time an accelerator was actively processing during the
last sample period, indicating TPU utilization."
metric.data()
["0.00", "0.00", "0.00", "0.00"]
# accelerator_0-3
Utilizzare le metriche per controllare l'utilizzo della TPU
Gli esempi seguenti mostrano come utilizzare le metriche della libreria di monitoraggio TPU per monitorare l'utilizzo della TPU.
Monitorare il ciclo di servizio TPU durante l'addestramento JAX
Scenario: stai eseguendo uno script di addestramento JAX e vuoi monitorare la metrica duty_cycle_pct della TPU durante il processo di addestramento per verificare che le TPU vengano utilizzate in modo efficace. Puoi
registrare questa metrica periodicamente durante l'addestramento per monitorare l'utilizzo della TPU.
Il seguente esempio di codice mostra come monitorare il ciclo di lavoro della TPU durante l'addestramento JAX:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
import time
# --- Your JAX model and training setup would go here ---
# --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
return jnp.sum(x)
def loss_fn(params, x, y):
preds = simple_model(x)
return jnp.mean((preds - y)**2)
def train_step(params, x, y, optimizer):
grads = jax.grad(loss_fn)(params, x, y)
return optimizer.update(grads, params)
key = jax.random.PRNGKey(0)
params = jnp.array([1.0, 2.0]) # Example params
optimizer = ... # Your optimizer (for example, optax.adam)
data_x = jnp.ones((10, 10))
data_y = jnp.zeros((10,))
num_epochs = 10
log_interval_steps = 2 # Log duty cycle every 2 steps
for epoch in range(num_epochs):
for step in range(5): # Example steps per epoch
params = train_step(params, data_x, data_y, optimizer)
if (step + 1) % log_interval_steps == 0:
# --- Integrate TPU Monitoring Library here to get duty_cycle ---
duty_cycle_metric = tpumonitoring.get_metric("duty_cycle_pct")
duty_cycle_data = duty_cycle_metric.data
print(f"Epoch {epoch+1}, Step {step+1}: TPU Duty Cycle Data:")
print(f" Description: {duty_cycle_metric.description}")
print(f" Data: {duty_cycle_data}")
# --- End TPU Monitoring Library Integration ---
# --- Rest of your training loop logic ---
time.sleep(0.1) # Simulate some computation
print("Training complete.")
Controlla l'utilizzo della HBM prima di eseguire l'inferenza JAX
Scenario: prima di eseguire l'inferenza con il modello JAX, controlla l'utilizzo attuale della HBM (High Bandwidth Memory) sulla TPU per verificare di avere memoria sufficiente e per ottenere una misurazione di base prima dell'inizio dell'inferenza.
# The following code sample shows how to check HBM utilization before JAX inference:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
# --- Your JAX model and inference setup would go here ---
# --- Example placeholder model (replace with your actual model loading/setup)---
def simple_model(x):
return jnp.sum(x)
key = jax.random.PRNGKey(0)
params = ... # Load your trained parameters
# Integrate the TPU Monitoring Library to get HBM utilization before inference
hbm_util_metric = tpumonitoring.get_metric("hbm_util")
hbm_util_data = hbm_util_metric.data
print("HBM Utilization Before Inference:")
print(f" Description: {hbm_util_metric.description}")
print(f" Data: {hbm_util_data}")
# End TPU Monitoring Library Integration
# Your Inference Logic
input_data = jnp.ones((1, 10)) # Example input
predictions = simple_model(input_data)
print("Inference Predictions:", predictions)
print("Inference complete.")
Controllare le metriche di rete
Scenario: stai eseguendo un carico di lavoro multi-host e multi-slice e vuoi connetterti a uno dei pod GKE o delle TPU utilizzando SSH per visualizzare le metriche di rete mentre il carico di lavoro è in esecuzione. I comandi possono anche essere incorporati direttamente nel workload multihost.
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
# --- Your JAX model and training setup goes here ---
# --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
return jnp.sum(x)
# --- details here ---
# ==============================================================================
# Metric 1: TCP Delivery Rate
# ==============================================================================
# This metric reflects the delivery rate of the TCP connection (bytes delivered / elapsed time).
# The output is a list of strings representing latency statistics:
# [mean, p50, p90, p95, p99.9]
# Example: ['100.00', '200.00', '300.00', '400.00', '500.00'] (Values in Mbps)
# Get the metric object
delivery_rate_metric = tpumonitoring.get_metric(metric_name_rate)
# Print the description provided by the library
print("Description:", delivery_rate_metric.description())
# Print the actual data payload
print("Data:", delivery_rate_metric.data())
# ==============================================================================
# Metric 2: TCP Minimum Round Trip Time (RTT)
# ==============================================================================
# This metric reflects the minimum RTT measured between sending a TCP packet
# and receiving the acknowledgement.
# The output is a list of strings representing latency statistics:
# [mean, p50, p90, p95, p99.9]
# Example: ['100.00', '200.00', '300.00', '400.00', '500.00'] (Values in us - microseconds)
# Get the metric object
min_rtt_metric = tpumonitoring.get_metric(metric_name_rtt)
# Print the description provided by the library
print("Description:", min_rtt_metric.description())
# Print the actual data payload
print("Data:", min_rtt_metric.data())
Frequenza di aggiornamento delle metriche TPU
La frequenza di aggiornamento delle metriche TPU è limitata a un minimo di un secondo. I dati delle metriche host vengono esportati a una frequenza fissa di 1 Hz. La latenza introdotta da questo processo di esportazione è trascurabile. Le metriche di runtime di LibTPU non sono soggette allo stesso vincolo di frequenza. Tuttavia, per coerenza, anche queste metriche vengono campionate a 1 Hz o 1 campione al secondo.
Modulo TPU-Z
TPU-Z è uno strumento di telemetria e debug per le TPU. Fornisce informazioni dettagliate
sullo stato di runtime per tutti i core TPU collegati a un host. La funzionalità viene fornita tramite il modulo tpuz, che fa parte del modulo libtpu.sdk nell'SDK Python libtpu. Il modulo fornisce uno snapshot
dello stato di ogni core.
Il caso d'uso principale di TPU-Z è la diagnosi di blocchi o deadlock nei carichi di lavoro TPU distribuiti. Puoi eseguire query sul servizio TPU-Z sugli host per acquisire lo stato di ogni core, confrontando i contatori di programma, le posizioni HLO e gli ID esecuzione in tutti i core per identificare le anomalie.
Utilizza la funzione get_core_state_summary() all'interno della libreria libtpu.sdk per
visualizzare le metriche TPU-Z:
summary = sdk.tpuz.get_core_state_summary()
L'output per le metriche TPU-Z viene fornito come dizionario. Di seguito è riportato un esempio troncato per un singolo core:
{
"host_name": "my-tpu-host-vm",
"core_states": {
"1": {
"core_id": {
"global_core_id": 1,
"chip_id": 0,
"core_on_chip": {
"type": "TPU_CORE_TYPE_TENSOR_CORE",
"index": 1
}
},
"sequencer_info": [
{
"sequencer_type": "TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER",
"sequencer_index": 0,
"pc": 4490,
"program_id": 3274167277388825310,
"run_id": 3
}
],
"program_fingerprint": "b'\\xbefB\\xc6\\x1eb\\xc1#\\xd0...'",
"queued_program_info": [],
"error_message": ""
}
// ...
}
}
Per recuperare informazioni sugli ottimizzatori di alto livello (HLO) su ogni core, imposta
il parametro include_hlo_info su True:
summary = sdk.tpuz.get_core_state_summary(include_hlo_info=True)
L'output include informazioni HLO aggiuntive:
"1": {
"core_id": {
"global_core_id": 1,
"chip_id": 0,
"core_on_chip": {
"type": "TPU_CORE_TYPE_TENSOR_CORE",
"index": 1
}
},
"sequencer_info": [
{
"sequencer_type": "TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER",
"sequencer_index": 0,
"pc": 17776,
"tag": 3,
"tracemark": 2147483646,
"program_id": 3230481660274331500,
"run_id": 81,
"hlo_location": "HLO: fusion.11; HLO computation: main.126_spmd",
"hlo_detailed_info": "[{\"details\":\"HloModule fusion.11, entry_computation_layout={(bf16[20>..."
}
],
"program_fingerprint": "\ufffdU\ufffd4j\u7c6e\ufffd\ufffd{\u0017\ufffd\ufffdHHV\ufffdD\ufffde\uff>",
"launch_id": 1394130914,
"queued_program_info": [
{
"run_id": 81,
"launch_id": 1394130914,
"program_fingerprint": "\ufffdU\ufffd4j\u7c6e\ufffd\ufffd{\u0017\ufffd\ufffdHHV\ufffdD\ufffde>"
}
]
}
Metriche TPU-Z
La funzione get_core_state_summary restituisce le metriche TPU-Z sotto forma di
dizionario con la seguente struttura.
CurrentCoreStateSummary
Il dizionario CurrentCoreStateSummary fornisce un riepilogo dettagliato dello stato di un singolo core TPU.
| Campo | Tipo | Descrizione |
|---|---|---|
core_id |
dizionario | Un dizionario TpuCoreIdentifier che contiene informazioni sull'ID del core TPU. |
sequencer_info |
elenco di dizionari | Un elenco di dizionari SequencerInfo che descrivono lo stato di ogni sequencer sul core. |
program_fingerprint |
byte | L'impronta del programma in esecuzione su questo core. |
launch_id |
integer | L'ID di avvio del programma corrente o più recente. |
queued_program_info |
elenco di dizionari | Un elenco di dizionari QueuedProgramInfo per i programmi in coda per l'esecuzione. |
error_message |
string | Eventuali messaggi di errore per questo core. |
TpuCoreIdentifier
Il dizionario TpuCoreIdentifier fornisce informazioni sull'ID dei core all'interno del
sistema TPU.
| Campo | Tipo | Descrizione |
|---|---|---|
global_core_id |
integer | L'ID del core. |
chip_id |
integer | L'ID del chip a cui appartiene il core. |
core_on_chip |
dizionario | Un dizionario TpuCoreOnChip che descrive il tipo di core e il suo indice sul chip. |
TpuCoreOnChip
Il dizionario TpuCoreOnChip contiene informazioni sulle proprietà di un core
all'interno di un chip specifico.
| Campo | Tipo | Descrizione |
|---|---|---|
type |
string | Il tipo di core TPU. Ad esempio: TPU_CORE_TYPE_TENSOR_CORE. |
index |
integer | L'indice del core sul chip. |
SequencerInfo
Il dizionario SequencerInfo contiene informazioni sullo stato di un singolo
sequencer su un core.
| Campo | Tipo | Descrizione |
|---|---|---|
sequencer_type |
string | Il tipo di sequencer. Ad esempio: TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER. |
sequencer_index |
integer | L'indice del sequencer (se sono presenti più sequencer dello stesso tipo). |
pc |
integer | Il valore attuale del Program Counter. |
program_id |
integer | L'ID associato a un'istanza specifica di un programma avviato per l'esecuzione su un core TPU. |
run_id |
integer | L'ID esecuzione associato a un'istanza specifica dell'esecuzione di un programma su un core TPU. |
hlo_location |
string | Informazioni sulla posizione dell'ottimizzatore di alto livello. |
hlo_detailed_info |
string | Informazioni dettagliate sull'ottimizzatore di alto livello. |
QueuedProgramInfo
Il dizionario QueuedProgramInfo contiene informazioni sui programmi in coda
per l'esecuzione su un core.
| Campo | Tipo | Descrizione |
|---|---|---|
run_id |
integer | L'ID esecuzione del programma in coda. |
launch_id |
integer | L'ID lancio del programma in coda. |
program_fingerprint |
byte | L'impronta del programma in coda. |
TPU-Z con JAX
Puoi accedere alle metriche TPU-Z nei carichi di lavoro JAX tramite la libreria libtpu.sdk.
Il seguente script Python utilizza JAX per il calcolo dei tensori ad alte prestazioni,
mentre utilizza contemporaneamente l'SDK libtpu in un thread in background per monitorare
lo stato e l'attività dell'hardware TPU sottostante.
Includi i seguenti pacchetti Python:
import jax
import jax.numpy as jnp
import time
import threading
from functools import partial
from libtpu import sdk
La funzione monitor_tpu_status utilizza un thread in background per mostrare continuamente
lo stato operativo dei core TPU mentre l'applicazione principale esegue un
carico di lavoro JAX. Funge da strumento di diagnostica in tempo reale.
def monitor_tpu_status():
"""Monitors TPU status in a background thread."""
while monitoring_active:
try:
summary = sdk.tpuz.get_core_state_summary(include_hlo_info=True)
if summary and 'core_states' in summary:
print(summary)
else:
print('WARNING: Call returned an empty or invalid summary.')
except RuntimeError as e:
print(f'FAIL: Error calling API: {e}')
except Exception as e:
print(f'FAIL: Unexpected error in monitor thread: {e}')
for _ in range(MONITORING_INTERVAL_SECONDS * 2):
if not monitoring_active:
break
time.sleep(0.5)
print('✅ Monitoring thread stopped.')
La funzione transformer_block implementa un livello completo dell'architettura Transformer, che è il blocco di base per gli LLM.
@partial(jax.jit, static_argnames=['num_heads'])
def transformer_block(params, x, num_heads=32):
"""A simplified but computationally intensive Transformer block."""
# Multi-head Self-Attention
qkv = jnp.dot(x, params['qkv_kernel'])
q, k, v = jnp.array_split(qkv, 3, axis=-1)
# Reshape for multi-head attention
q = q.reshape(q.shape[0], q.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
k = k.reshape(k.shape[0], k.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
v = v.reshape(v.shape[0], v.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
# Scaled dot-product attention
attention_scores = jnp.einsum('nhqd,nhkd->nhqk', q, k) / jnp.sqrt(q.shape[-1])
attention_weights = jax.nn.softmax(attention_scores, axis=-1)
attention_output = jnp.einsum('nhqk,nhvd->nhqd', attention_weights, v)
attention_output = attention_output.transpose(0, 2, 1, 3).reshape(x.shape)
attention_output = jnp.dot(attention_output, params['o_kernel'])
# Residual connection and Layer Normalization 1
h1 = x + attention_output
h1_norm = h1 - jnp.mean(h1, axis=-1, keepdims=True)
h1_norm = h1_norm / jnp.sqrt(
jnp.mean(jnp.square(h1_norm), axis=-1, keepdims=True) + 1e-5
)
# Feed-Forward Network
ffn_hidden = jax.nn.gelu(jnp.dot(h1_norm, params['ffn1_kernel']))
ffn_output = jnp.dot(ffn_hidden, params['ffn2_kernel'])
# Residual connection and Layer Normalization 2
h2 = h1_norm + ffn_output
h2_norm = h2 - jnp.mean(h2, axis=-1, keepdims=True)
h2_norm = h2_norm / jnp.sqrt(
jnp.mean(jnp.square(h2_norm), axis=-1, keepdims=True) + 1e-5
)
return h2_norm
La funzione main orchestra la configurazione del calcolo JAX, avvia il
monitoraggio TPU in background ed esegue il ciclo principale del workload.
def main():
num_devices = jax.device_count()
print(f"Running on {num_devices} devices.")
batch_size = 128 * num_devices
seq_len = 512
embed_dim = 1024
ffn_dim = embed_dim * 4
key = jax.random.PRNGKey(0)
params = {
'qkv_kernel': jax.random.normal(
key, (embed_dim, embed_dim * 3), dtype=jnp.bfloat16
),
'o_kernel': jax.random.normal(
key, (embed_dim, embed_dim), dtype=jnp.bfloat16
),
'ffn1_kernel': jax.random.normal(
key, (embed_dim, ffn_dim), dtype=jnp.bfloat16
),
'ffn2_kernel': jax.random.normal(
key, (ffn_dim, embed_dim), dtype=jnp.bfloat16
),
}
input_data = jax.random.normal(
key, (batch_size, seq_len, embed_dim), dtype=jnp.bfloat16
)
input_data = jax.device_put(input_data)
monitor_thread = threading.Thread(target=monitor_tpu_status)
monitor_thread.start()
print("Starting JAX computation loop...")
start_time = time.time()
iterations = 0
while time.time() - start_time < JOB_DURATION_SECONDS:
result = transformer_block(params, input_data)
result.block_until_ready()
iterations += 1
print(f' -> Jax iteration {iterations} complete.', end='\r')
print(f"\nCompleted {iterations} iterations in {JOB_DURATION_SECONDS} seconds.")
global monitoring_active
monitoring_active = False
monitor_thread.join()
if __name__ == '__main__':
main()
Risoluzione dei problemi
Questa sezione fornisce informazioni per la risoluzione dei problemi per aiutarti a identificare e risolvere i problemi che potresti riscontrare durante l'utilizzo della libreria di monitoraggio TPU.
Funzionalità o metriche mancanti
Se non riesci a visualizzare alcune funzionalità o metriche, la causa più comune è una versione
obsoleta di libtpu. Le funzionalità e le metriche della libreria di monitoraggio TPU sono
incluse nelle release di libtpu e le versioni obsolete potrebbero non includere nuove
funzionalità e metriche.
Controlla la versione di libtpu in esecuzione nel tuo ambiente:
Riga di comando:
pip show libtpu
Python:
import libtpu
print(libtpu.__version__)
Se non utilizzi l'ultima versione di
libtpu, utilizza il seguente comando per aggiornare la libreria:
pip install --upgrade libtpu