Libreria di monitoraggio TPU

Ottieni informazioni approfondite sulle prestazioni e sul comportamento dell'hardware Cloud TPU con funzionalità di monitoraggio avanzate di TPU, basate direttamente sul livello software di base, LibTPU. Sebbene LibTPU comprenda driver, librerie di rete, il compilatore XLA e il runtime TPU per l'interazione con le TPU, questo documento si concentra sulla libreria di monitoraggio TPU.

La libreria di monitoraggio TPU fornisce:

  • Osservabilità completa: accedi all'API Telemetry e alla suite di metriche, che forniscono informazioni dettagliate sul rendimento operativo e sui comportamenti specifici delle TPU.

  • Toolkit di diagnostica: fornisce un SDK e un'interfaccia a riga di comando (CLI) progettati per consentire il debug e l'analisi approfondita delle prestazioni delle risorse TPU.

Queste funzionalità di monitoraggio sono progettate per essere una soluzione di primo livello rivolta ai clienti, fornendoti gli strumenti essenziali per ottimizzare in modo efficace i tuoi workload TPU.

La libreria di monitoraggio TPU fornisce informazioni dettagliate sul rendimento dei carichi di lavoro di machine learning sull'hardware TPU. È progettato per aiutarti a comprendere l'utilizzo delle TPU, identificare i colli di bottiglia ed eseguire il debug dei problemi di prestazioni. Fornisce informazioni più dettagliate rispetto alle metriche di interruzione, goodput e altre metriche.

Inizia a utilizzare la libreria di monitoraggio TPU

Accedere a questi potenti approfondimenti è semplice. La funzionalità di monitoraggio della TPU è integrata nell'SDK LibTPU, pertanto è inclusa quando installi LibTPU.

Installare LibTPU

pip install libtpu

In alternativa, gli aggiornamenti di LibTPU sono coordinati con le release di JAX, il che significa che quando installi l'ultima release di JAX (rilasciata mensilmente), in genere viene installata l'ultima versione compatibile di LibTPU e le relative funzionalità.

Installare JAX

pip install -U "jax[tpu]"

Per gli utenti di PyTorch, l'installazione di PyTorch/XLA fornisce le funzionalità di monitoraggio più recenti di LibTPU e TPU.

Installare PyTorch/XLA

pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' \
  -f https://storage.googleapis.com/libtpu-releases/index.html \
  -f https://storage.googleapis.com/libtpu-wheels/index.html

  # Optional: if you're using custom kernels, install pallas dependencies
pip install 'torch_xla[pallas]' \
  -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
  -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

Per saperne di più sull'installazione di PyTorch/XLA, consulta la sezione Installazione nel repository GitHub di PyTorch/XLA.

Importa la libreria in Python

Per iniziare a utilizzare la libreria di monitoraggio TPU, devi importare il modulo libtpu nel codice Python.

from libtpu.sdk import tpumonitoring

Elenca tutte le funzionalità supportate

Elenca tutti i nomi delle metriche e le funzionalità che supportano:


from libtpu.sdk import tpumonitoring

tpumonitoring.help()
" libtpu.sdk.monitoring.help():
      List all supported functionality.

  libtpu.sdk.monitoring.list_support_metrics()
      List support metric names in the list of str format.

  libtpu.sdk.monitoring.get_metric(metric_name:str)
      Get metric data with metric name. It represents the snapshot mode.
      The metric data is a object with `description()` and `data()` methods,
      where the `description()` returns a string describe the format of data
      and data unit, `data()` returns the metric data in the list in str format.
"

Metriche supportate

Il seguente esempio di codice mostra come elencare tutti i nomi delle metriche supportate:

from libtpu.sdk import tpumonitoring

tpumonitoring.list_supported_metrics()

["duty_cycle_pct", "tensorcore_util", "hbm_util", ...]

La tabella seguente mostra tutte le metriche e le relative definizioni:

Metrica Definizione Nome della metrica per l'API Valori di esempio
Utilizzo di Tensor Core Misura la percentuale di utilizzo di TensorCore, calcolata come percentuale di operazioni incluse nelle operazioni TensorCore. Campionamento di 10 microsecondi ogni secondo. Non puoi modificare la frequenza di campionamento. Questa metrica consente di monitorare l'efficienza dei workload sui dispositivi TPU. tensorcore_util ['1.11', '2.22', '3.33', '4.44']

# utilization percentage for accelerator ID 0-3.
Percentuale del ciclo di lavoro Percentuale di tempo nell'ultimo periodo di campionamento (ogni 5 secondi; può essere ottimizzata impostando il flag LIBTPU_INIT_ARG) durante il quale l'acceleratore ha eseguito attivamente l'elaborazione (registrata con i cicli utilizzati per eseguire i programmi HLO nell'ultimo periodo di campionamento). Questa metrica rappresenta il livello di utilizzo di una TPU. La metrica viene emessa per chip. duty_cycle_pct ['10.00', '20.00', '30.00', '40.00']

# Duty cycle percentage for accelerator ID 0-3.
HBM Capacity Total Questa metrica indica la capacità totale HBM in byte. hbm_capacity_total ['30000000000', '30000000000', '30000000000', '30000000000']

# Capacità HBM totale in byte collegata all'ID acceleratore 0-3.
Utilizzo capacità HBM Questa metrica indica l'utilizzo della capacità HBM in byte nell'ultimo periodo di campionamento (ogni 5 secondi; può essere modificato impostando il flag LIBTPU_INIT_ARG). hbm_capacity_usage ['100', '200', '300', '400']

# Utilizzo della capacità per HBM in byte collegati all'ID acceleratore 0-3.
Latenza di trasferimento del buffer Latenze di trasferimento di rete per il traffico multislice su larga scala. Questa visualizzazione ti consente di comprendere l'ambiente di rendimento complessivo della rete. buffer_transfer_latency ["'8MB+', '2233.25', '2182.02', '3761.93', '19277.01', '53553.6'"]

# Dimensione buffer, media, p50, p90, p99, p99.9 della distribuzione della latenza di trasferimento di rete.
Metriche di distribuzione del tempo di esecuzione delle operazioni di alto livello Fornisce informazioni dettagliate sul rendimento dello stato di esecuzione del binario compilato HLO, consentendo il rilevamento della regressione e il debug a livello di modello. hlo_exec_timing ["'tensorcore-0', '10.00', '10.00', '20.00', '30.00', '40.00'"]

# The HLO execution time duration distribution for CoreType-CoreID with mean, p50, p90, p95, p999.
Dimensioni della coda dello strumento di ottimizzazione di alto livello Il monitoraggio delle dimensioni della coda di esecuzione HLO tiene traccia del numero di programmi HLO compilati in attesa o in fase di esecuzione. Questa metrica rivela la congestione della pipeline di esecuzione, consentendo l'identificazione di colli di bottiglia delle prestazioni nell'esecuzione hardware, nell'overhead dei driver o nell'allocazione delle risorse. hlo_queue_size ["tensorcore-0: 1", "tensorcore-1: 2"]

# Measures queue size for CoreType-CoreID.
Latenza end-to-end collettiva Questa metrica misura la latenza collettiva end-to-end su DCN in microsecondi, dall'host che avvia l'operazione a tutti i peer che ricevono l'output. Include la riduzione e l'invio dei dati lato host all'unità TPU. I risultati sono stringhe che descrivono in dettaglio le dimensioni, il tipo e le latenze medie, p50, p90, p95 e p99,9 del buffer. collective_e2e_latency ["8MB+-ALL_REDUCE, 1000, 2000, 3000, 4000, 5000", …]

# Transfer size-collective op, mean, p50, p90, p95, p999 of collective end to end latency.
Latenza di round trip a livello di trasporto Distribuzione dei Round Trip Time (RTT) minimi osservati sulle connessioni TCP utilizzate da gRPC per il traffico TPU multislice. grpc_tcp_min_round_trip_times ['27.63, 29.03, 38.52, 41.63, 52.74']

# Represents the distribution's mean, p50, p90, p95, and p99.9 percentiles in microseconds (µs).
Velocità effettiva a livello di trasporto Distribuzione cumulativa del throughput recente delle connessioni TCP utilizzate da gRPC per il traffico TPU multislice. grpc_tcp_delivery_rates ['11354.89, 10986.35, 18239.55, 25718.35, 44841.55']

# Represents the distribution's mean, p50, p90, p95, and p99.9 percentiles in microseconds (µs).

Lettura dei dati delle metriche

Per leggere i dati delle metriche, specifica il nome della metrica quando chiami la funzione tpumonitoring.get_metric. Puoi inserire controlli delle metriche ad hoc nel codice con prestazioni scarse per identificare se i problemi di prestazioni derivano da software o hardware.

Il seguente esempio di codice mostra come leggere la metrica duty_cycle:

from libtpu.sdk import tpumonitoring

metric = tpumonitoring.get_metric("duty_cycle_pct")

metric.description()
"The metric provides a list of duty cycle percentages, one for each
accelerator (from accelerator_0 to accelerator_x). The duty cycle represents
the percentage of time an accelerator was actively processing during the
last sample period, indicating TPU utilization."

metric.data()
["0.00", "0.00", "0.00", "0.00"]

# accelerator_0-3

Utilizzare le metriche per controllare l'utilizzo della TPU

Gli esempi seguenti mostrano come utilizzare le metriche della libreria di monitoraggio TPU per monitorare l'utilizzo della TPU.

Monitorare il ciclo di servizio TPU durante l'addestramento JAX

Scenario: stai eseguendo uno script di addestramento JAX e vuoi monitorare la metrica duty_cycle_pct della TPU durante il processo di addestramento per verificare che le TPU vengano utilizzate in modo efficace. Puoi registrare questa metrica periodicamente durante l'addestramento per monitorare l'utilizzo della TPU.

Il seguente esempio di codice mostra come monitorare il ciclo di lavoro della TPU durante l'addestramento JAX:

import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
import time

 # --- Your JAX model and training setup would go here ---
 #  --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
    return jnp.sum(x)

def loss_fn(params, x, y):
    preds = simple_model(x)
    return jnp.mean((preds - y)**2)

def train_step(params, x, y, optimizer):
    grads = jax.grad(loss_fn)(params, x, y)
    return optimizer.update(grads, params)

key = jax.random.PRNGKey(0)
params = jnp.array([1.0, 2.0]) # Example params
optimizer = ... # Your optimizer (for example, optax.adam)
data_x = jnp.ones((10, 10))
data_y = jnp.zeros((10,))

num_epochs = 10
log_interval_steps = 2  # Log duty cycle every 2 steps

for epoch in range(num_epochs):
    for step in range(5): # Example steps per epoch

        params = train_step(params, data_x, data_y, optimizer)

        if (step + 1) % log_interval_steps == 0:
            # --- Integrate TPU Monitoring Library here to get duty_cycle ---
            duty_cycle_metric = tpumonitoring.get_metric("duty_cycle_pct")
            duty_cycle_data = duty_cycle_metric.data
            print(f"Epoch {epoch+1}, Step {step+1}: TPU Duty Cycle Data:")
            print(f"  Description: {duty_cycle_metric.description}")
            print(f"  Data: {duty_cycle_data}")
            # --- End TPU Monitoring Library Integration ---

        # --- Rest of your training loop logic ---
        time.sleep(0.1) # Simulate some computation

print("Training complete.")

Controlla l'utilizzo della HBM prima di eseguire l'inferenza JAX

Scenario: prima di eseguire l'inferenza con il modello JAX, controlla l'utilizzo attuale della HBM (High Bandwidth Memory) sulla TPU per verificare di avere memoria sufficiente e per ottenere una misurazione di base prima dell'inizio dell'inferenza.

# The following code sample shows how to check HBM utilization before JAX inference:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring

  # --- Your JAX model and inference setup would go here ---
  # --- Example placeholder model (replace with your actual model loading/setup)---
def simple_model(x):
    return jnp.sum(x)

key = jax.random.PRNGKey(0)
params = ... # Load your trained parameters

  # Integrate the TPU Monitoring Library to get HBM utilization before inference
hbm_util_metric = tpumonitoring.get_metric("hbm_util")
hbm_util_data = hbm_util_metric.data
print("HBM Utilization Before Inference:")
print(f"  Description: {hbm_util_metric.description}")
print(f"  Data: {hbm_util_data}")
  # End TPU Monitoring Library Integration

  # Your Inference Logic
input_data = jnp.ones((1, 10)) # Example input
predictions = simple_model(input_data)
print("Inference Predictions:", predictions)

print("Inference complete.")

Controllare le metriche di rete

Scenario: stai eseguendo un carico di lavoro multi-host e multi-slice e vuoi connetterti a uno dei pod GKE o delle TPU utilizzando SSH per visualizzare le metriche di rete mentre il carico di lavoro è in esecuzione. I comandi possono anche essere incorporati direttamente nel workload multihost.

import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring

 # --- Your JAX model and training setup goes here ---
 #  --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
    return jnp.sum(x)
     # --- details here ---

    # ==============================================================================
    # Metric 1: TCP Delivery Rate
    # ==============================================================================
    # This metric reflects the delivery rate of the TCP connection (bytes delivered / elapsed time).
    # The output is a list of strings representing latency statistics:
    # [mean, p50, p90, p95, p99.9]
    # Example: ['100.00', '200.00', '300.00', '400.00', '500.00'] (Values in Mbps)

        # Get the metric object
        delivery_rate_metric = tpumonitoring.get_metric(metric_name_rate)

        # Print the description provided by the library
        print("Description:", delivery_rate_metric.description())

        # Print the actual data payload
        print("Data:", delivery_rate_metric.data())

    # ==============================================================================
    # Metric 2: TCP Minimum Round Trip Time (RTT)
    # ==============================================================================
    # This metric reflects the minimum RTT measured between sending a TCP packet
    # and receiving the acknowledgement.
    # The output is a list of strings representing latency statistics:
    # [mean, p50, p90, p95, p99.9]
    # Example: ['100.00', '200.00', '300.00', '400.00', '500.00'] (Values in us - microseconds)

        # Get the metric object
        min_rtt_metric = tpumonitoring.get_metric(metric_name_rtt)

        # Print the description provided by the library
        print("Description:", min_rtt_metric.description())

        # Print the actual data payload
        print("Data:", min_rtt_metric.data())

Frequenza di aggiornamento delle metriche TPU

La frequenza di aggiornamento delle metriche TPU è limitata a un minimo di un secondo. I dati delle metriche host vengono esportati a una frequenza fissa di 1 Hz. La latenza introdotta da questo processo di esportazione è trascurabile. Le metriche di runtime di LibTPU non sono soggette allo stesso vincolo di frequenza. Tuttavia, per coerenza, anche queste metriche vengono campionate a 1 Hz o 1 campione al secondo.

Modulo TPU-Z

TPU-Z è uno strumento di telemetria e debug per le TPU. Fornisce informazioni dettagliate sullo stato di runtime per tutti i core TPU collegati a un host. La funzionalità viene fornita tramite il modulo tpuz, che fa parte del modulo libtpu.sdk nell'SDK Python libtpu. Il modulo fornisce uno snapshot dello stato di ogni core.

Il caso d'uso principale di TPU-Z è la diagnosi di blocchi o deadlock nei carichi di lavoro TPU distribuiti. Puoi eseguire query sul servizio TPU-Z sugli host per acquisire lo stato di ogni core, confrontando i contatori di programma, le posizioni HLO e gli ID esecuzione in tutti i core per identificare le anomalie.

Utilizza la funzione get_core_state_summary() all'interno della libreria libtpu.sdk per visualizzare le metriche TPU-Z:

summary = sdk.tpuz.get_core_state_summary()

L'output per le metriche TPU-Z viene fornito come dizionario. Di seguito è riportato un esempio troncato per un singolo core:

{
  "host_name": "my-tpu-host-vm",
  "core_states": {
    "1": {
      "core_id": {
        "global_core_id": 1,
        "chip_id": 0,
        "core_on_chip": {
          "type": "TPU_CORE_TYPE_TENSOR_CORE",
          "index": 1
        }
      },
      "sequencer_info": [
        {
          "sequencer_type": "TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER",
          "sequencer_index": 0,
          "pc": 4490,
          "program_id": 3274167277388825310,
          "run_id": 3
        }
      ],
      "program_fingerprint": "b'\\xbefB\\xc6\\x1eb\\xc1#\\xd0...'",
      "queued_program_info": [],
      "error_message": ""
    }
    // ...
  }
}

Per recuperare informazioni sugli ottimizzatori di alto livello (HLO) su ogni core, imposta il parametro include_hlo_info su True:

summary = sdk.tpuz.get_core_state_summary(include_hlo_info=True)

L'output include informazioni HLO aggiuntive:

"1": {
    "core_id": {
      "global_core_id": 1,
      "chip_id": 0,
      "core_on_chip": {
        "type": "TPU_CORE_TYPE_TENSOR_CORE",
        "index": 1
      }
    },
    "sequencer_info": [
      {
        "sequencer_type": "TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER",
        "sequencer_index": 0,
        "pc": 17776,
        "tag": 3,
        "tracemark": 2147483646,
        "program_id": 3230481660274331500,
        "run_id": 81,
        "hlo_location": "HLO: fusion.11; HLO computation: main.126_spmd",
        "hlo_detailed_info": "[{\"details\":\"HloModule fusion.11, entry_computation_layout={(bf16[20>..."
      }
    ],
    "program_fingerprint": "\ufffdU\ufffd4j\u7c6e\ufffd\ufffd{\u0017\ufffd\ufffdHHV\ufffdD\ufffde\uff>",
    "launch_id": 1394130914,
    "queued_program_info": [
      {
        "run_id": 81,
        "launch_id": 1394130914,
        "program_fingerprint": "\ufffdU\ufffd4j\u7c6e\ufffd\ufffd{\u0017\ufffd\ufffdHHV\ufffdD\ufffde>"
      }
    ]
  }

Metriche TPU-Z

La funzione get_core_state_summary restituisce le metriche TPU-Z sotto forma di dizionario con la seguente struttura.

CurrentCoreStateSummary

Il dizionario CurrentCoreStateSummary fornisce un riepilogo dettagliato dello stato di un singolo core TPU.

Campo Tipo Descrizione
core_id dizionario Un dizionario TpuCoreIdentifier che contiene informazioni sull'ID del core TPU.
sequencer_info elenco di dizionari Un elenco di dizionari SequencerInfo che descrivono lo stato di ogni sequencer sul core.
program_fingerprint byte L'impronta del programma in esecuzione su questo core.
launch_id integer L'ID di avvio del programma corrente o più recente.
queued_program_info elenco di dizionari Un elenco di dizionari QueuedProgramInfo per i programmi in coda per l'esecuzione.
error_message string Eventuali messaggi di errore per questo core.

TpuCoreIdentifier

Il dizionario TpuCoreIdentifier fornisce informazioni sull'ID dei core all'interno del sistema TPU.

Campo Tipo Descrizione
global_core_id integer L'ID del core.
chip_id integer L'ID del chip a cui appartiene il core.
core_on_chip dizionario Un dizionario TpuCoreOnChip che descrive il tipo di core e il suo indice sul chip.

TpuCoreOnChip

Il dizionario TpuCoreOnChip contiene informazioni sulle proprietà di un core all'interno di un chip specifico.

Campo Tipo Descrizione
type string Il tipo di core TPU. Ad esempio: TPU_CORE_TYPE_TENSOR_CORE.
index integer L'indice del core sul chip.

SequencerInfo

Il dizionario SequencerInfo contiene informazioni sullo stato di un singolo sequencer su un core.

Campo Tipo Descrizione
sequencer_type string Il tipo di sequencer. Ad esempio: TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER.
sequencer_index integer L'indice del sequencer (se sono presenti più sequencer dello stesso tipo).
pc integer Il valore attuale del Program Counter.
program_id integer L'ID associato a un'istanza specifica di un programma avviato per l'esecuzione su un core TPU.
run_id integer L'ID esecuzione associato a un'istanza specifica dell'esecuzione di un programma su un core TPU.
hlo_location string Informazioni sulla posizione dell'ottimizzatore di alto livello.
hlo_detailed_info string Informazioni dettagliate sull'ottimizzatore di alto livello.

QueuedProgramInfo

Il dizionario QueuedProgramInfo contiene informazioni sui programmi in coda per l'esecuzione su un core.

Campo Tipo Descrizione
run_id integer L'ID esecuzione del programma in coda.
launch_id integer L'ID lancio del programma in coda.
program_fingerprint byte L'impronta del programma in coda.

TPU-Z con JAX

Puoi accedere alle metriche TPU-Z nei carichi di lavoro JAX tramite la libreria libtpu.sdk. Il seguente script Python utilizza JAX per il calcolo dei tensori ad alte prestazioni, mentre utilizza contemporaneamente l'SDK libtpu in un thread in background per monitorare lo stato e l'attività dell'hardware TPU sottostante.

Includi i seguenti pacchetti Python:

import jax
import jax.numpy as jnp
import time
import threading
from functools import partial
from libtpu import sdk

La funzione monitor_tpu_status utilizza un thread in background per mostrare continuamente lo stato operativo dei core TPU mentre l'applicazione principale esegue un carico di lavoro JAX. Funge da strumento di diagnostica in tempo reale.

def monitor_tpu_status():
  """Monitors TPU status in a background thread."""

  while monitoring_active:
    try:
      summary = sdk.tpuz.get_core_state_summary(include_hlo_info=True)
      if summary and 'core_states' in summary:
        print(summary)
      else:
        print('WARNING: Call returned an empty or invalid summary.')
    except RuntimeError as e:
      print(f'FAIL: Error calling API: {e}')
    except Exception as e:
      print(f'FAIL: Unexpected error in monitor thread: {e}')

    for _ in range(MONITORING_INTERVAL_SECONDS * 2):
      if not monitoring_active:
        break
      time.sleep(0.5)
  print('✅ Monitoring thread stopped.')

La funzione transformer_block implementa un livello completo dell'architettura Transformer, che è il blocco di base per gli LLM.

@partial(jax.jit, static_argnames=['num_heads'])
def transformer_block(params, x, num_heads=32):
  """A simplified but computationally intensive Transformer block."""
  # Multi-head Self-Attention
  qkv = jnp.dot(x, params['qkv_kernel'])
  q, k, v = jnp.array_split(qkv, 3, axis=-1)

  # Reshape for multi-head attention
  q = q.reshape(q.shape[0], q.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
  k = k.reshape(k.shape[0], k.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
  v = v.reshape(v.shape[0], v.shape[1], num_heads, -1).transpose(0, 2, 1, 3)

  # Scaled dot-product attention
  attention_scores = jnp.einsum('nhqd,nhkd->nhqk', q, k) / jnp.sqrt(q.shape[-1])
  attention_weights = jax.nn.softmax(attention_scores, axis=-1)
  attention_output = jnp.einsum('nhqk,nhvd->nhqd', attention_weights, v)
  attention_output = attention_output.transpose(0, 2, 1, 3).reshape(x.shape)

  attention_output = jnp.dot(attention_output, params['o_kernel'])

  # Residual connection and Layer Normalization 1
  h1 = x + attention_output
  h1_norm = h1 - jnp.mean(h1, axis=-1, keepdims=True)
  h1_norm = h1_norm / jnp.sqrt(
      jnp.mean(jnp.square(h1_norm), axis=-1, keepdims=True) + 1e-5
  )

  # Feed-Forward Network
  ffn_hidden = jax.nn.gelu(jnp.dot(h1_norm, params['ffn1_kernel']))
  ffn_output = jnp.dot(ffn_hidden, params['ffn2_kernel'])

  # Residual connection and Layer Normalization 2
  h2 = h1_norm + ffn_output
  h2_norm = h2 - jnp.mean(h2, axis=-1, keepdims=True)
  h2_norm = h2_norm / jnp.sqrt(
      jnp.mean(jnp.square(h2_norm), axis=-1, keepdims=True) + 1e-5
  )

  return h2_norm

La funzione main orchestra la configurazione del calcolo JAX, avvia il monitoraggio TPU in background ed esegue il ciclo principale del workload.

def main():
  num_devices = jax.device_count()
  print(f"Running on {num_devices} devices.")

  batch_size = 128 * num_devices
  seq_len = 512
  embed_dim = 1024
  ffn_dim = embed_dim * 4

  key = jax.random.PRNGKey(0)

  params = {
      'qkv_kernel': jax.random.normal(
          key, (embed_dim, embed_dim * 3), dtype=jnp.bfloat16
      ),
      'o_kernel': jax.random.normal(
          key, (embed_dim, embed_dim), dtype=jnp.bfloat16
      ),
      'ffn1_kernel': jax.random.normal(
          key, (embed_dim, ffn_dim), dtype=jnp.bfloat16
      ),
      'ffn2_kernel': jax.random.normal(
          key, (ffn_dim, embed_dim), dtype=jnp.bfloat16
      ),
  }
  input_data = jax.random.normal(
      key, (batch_size, seq_len, embed_dim), dtype=jnp.bfloat16
  )
  input_data = jax.device_put(input_data)
  monitor_thread = threading.Thread(target=monitor_tpu_status)
  monitor_thread.start()
  print("Starting JAX computation loop...")
  start_time = time.time()
  iterations = 0
  while time.time() - start_time < JOB_DURATION_SECONDS:
    result = transformer_block(params, input_data)
    result.block_until_ready()
    iterations += 1
    print(f'  -> Jax iteration {iterations} complete.', end='\r')

  print(f"\nCompleted {iterations} iterations in {JOB_DURATION_SECONDS} seconds.")

  global monitoring_active
  monitoring_active = False
  monitor_thread.join()

if __name__ == '__main__':
  main()

Risoluzione dei problemi

Questa sezione fornisce informazioni per la risoluzione dei problemi per aiutarti a identificare e risolvere i problemi che potresti riscontrare durante l'utilizzo della libreria di monitoraggio TPU.

Funzionalità o metriche mancanti

Se non riesci a visualizzare alcune funzionalità o metriche, la causa più comune è una versione obsoleta di libtpu. Le funzionalità e le metriche della libreria di monitoraggio TPU sono incluse nelle release di libtpu e le versioni obsolete potrebbero non includere nuove funzionalità e metriche.

Controlla la versione di libtpu in esecuzione nel tuo ambiente:

Riga di comando:

pip show libtpu

Python:

import libtpu

print(libtpu.__version__)

Se non utilizzi l'ultima versione di libtpu, utilizza il seguente comando per aggiornare la libreria:

pip install --upgrade libtpu