Biblioteca de monitorização de TPUs

Desbloqueie estatísticas detalhadas sobre o desempenho e o comportamento do hardware da Cloud TPU com capacidades de monitorização avançadas da TPU, criadas diretamente na camada de software fundamental, a LibTPU. Embora a LibTPU abranja controladores, bibliotecas de rede, o compilador XLA e o tempo de execução da TPU para interagir com as TPUs, o foco deste documento é a biblioteca de monitorização da TPU.

A biblioteca de monitorização de TPU oferece:

  • Observabilidade abrangente: aceda à API de telemetria e ao conjunto de métricas, que fornecem estatísticas detalhadas sobre o desempenho operacional e os comportamentos específicos das suas TPUs.

  • Kits de ferramentas de diagnóstico: fornece um SDK e uma interface de linhas de comando (CLI) concebidos para permitir a depuração e a análise detalhada do desempenho dos seus recursos de TPU.

Estas funcionalidades de monitorização foram concebidas para serem uma solução de nível superior orientada para o cliente, que lhe oferece as ferramentas essenciais para otimizar os seus encargos de trabalho de TPU de forma eficaz.

A biblioteca de monitorização da TPU fornece informações detalhadas sobre o desempenho das cargas de trabalho de aprendizagem automática no hardware da TPU. Foi concebida para ajudar a compreender a utilização da TPU, identificar gargalos e depurar problemas de desempenho. Dá-lhe informações mais detalhadas do que as métricas de interrupção, as métricas de débito útil e outras métricas.

Comece a usar a biblioteca de monitorização de TPUs

Aceder a estas estatísticas avançadas é simples. A funcionalidade de monitorização da TPU está integrada no SDK LibTPU, pelo que a funcionalidade está incluída quando instala o LibTPU.

Instale a LibTPU

pip install libtpu

Em alternativa, as atualizações da LibTPU são coordenadas com os lançamentos do JAX, o que significa que, quando instala o lançamento mais recente do JAX (lançado mensalmente), este fixa normalmente a versão mais recente da LibTPU compatível e as respetivas funcionalidades.

Instale o JAX

pip install -U "jax[tpu]"

Para os utilizadores do PyTorch, a instalação do PyTorch/XLA fornece a funcionalidade de monitorização do LibTPU e TPU mais recente.

Instale o PyTorch/XLA

pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' \
  -f https://storage.googleapis.com/libtpu-releases/index.html \
  -f https://storage.googleapis.com/libtpu-wheels/index.html

  # Optional: if you're using custom kernels, install pallas dependencies
pip install 'torch_xla[pallas]' \
  -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
  -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

Para mais informações sobre a instalação do PyTorch/XLA, consulte a secção Instalação no repositório do GitHub do PyTorch/XLA.

Importe a biblioteca no Python

Para começar a usar a biblioteca de monitorização de TPUs, tem de importar o módulo libtpu no seu código Python.

from libtpu.sdk import tpumonitoring

Apresentar todas as funcionalidades suportadas

Indique todos os nomes das métricas e a funcionalidade que suportam:


from libtpu.sdk import tpumonitoring

tpumonitoring.help()
" libtpu.sdk.monitoring.help():
      List all supported functionality.

  libtpu.sdk.monitoring.list_support_metrics()
      List support metric names in the list of str format.

  libtpu.sdk.monitoring.get_metric(metric_name:str)
      Get metric data with metric name. It represents the snapshot mode.
      The metric data is a object with `description()` and `data()` methods,
      where the `description()` returns a string describe the format of data
      and data unit, `data()` returns the metric data in the list in str format.
"

Métricas compatíveis

O seguinte exemplo de código mostra como listar todos os nomes de métricas suportados:

from libtpu.sdk import tpumonitoring

tpumonitoring.list_supported_metrics()

["duty_cycle_pct", "tensorcore_util", "hbm_util", ...]

A tabela seguinte mostra todas as métricas e as respetivas definições:

Métrica Definição Nome da métrica para a API Valores de exemplo
Utilização do núcleo Tensor Mede a percentagem da utilização dos TensorCores, calculada como a percentagem de operações que fazem parte das operações dos TensorCores. Amostragem de 10 microssegundos a cada 1 segundo. Não pode modificar a taxa de amostragem. Esta métrica permite-lhe monitorizar a eficiência das suas cargas de trabalho em dispositivos TPU. tensorcore_util ['1.11', '2.22', '3.33', '4.44']

# utilization percentage for accelerator ID 0-3.
Percentagem do ciclo de atividade Percentagem do tempo durante o período de amostragem anterior (a cada 5 segundos; pode ser ajustada através da definição da flag LIBTPU_INIT_ARG) durante o qual o acelerador estava a processar ativamente (registado com os ciclos usados para executar programas HLO durante o último período de amostragem). Esta métrica representa o nível de ocupação de uma TPU. A métrica é emitida por chip. duty_cycle_pct ['10.00', '20.00', '30.00', '40.00']

# Percentagem do ciclo de trabalho para o ID do acelerador 0-3.
HBM Capacity Total Esta métrica indica a capacidade total de HBM em bytes. hbm_capacity_total ['30000000000', '30000000000', '30000000000', '30000000000']

# Capacidade total de HBM em bytes associada ao ID do acelerador 0-3.
Utilização da capacidade de HBM Esta métrica comunica a utilização da capacidade de HBM em bytes durante o período de amostragem anterior (a cada 5 segundos; pode ser ajustada através da definição da flag LIBTPU_INIT_ARG). hbm_capacity_usage ['100', '200', '300', '400']

# Utilização da capacidade para HBM em bytes anexados ao ID do acelerador 0-3.
Latência de transferência da memória intermédia Latências de transferência de rede para tráfego multislice de grande escala. Esta visualização permite-lhe compreender o ambiente de desempenho geral da rede. buffer_transfer_latency ["'8MB+', '2233.25', '2182.02', '3761.93', '19277.01', '53553.6'"]

# buffer size, mean, p50, p90, p99, p99.9 of network transfer latency distribution.
Métricas de distribuição do tempo de execução de operações de nível elevado Fornece estatísticas de desempenho detalhadas sobre o estado de execução do binário compilado HLO, o que permite a deteção de regressão e a depuração ao nível do modelo. hlo_exec_timing ["'tensorcore-0', '10.00', '10.00', '20.00', '30.00', '40.00'"]

# A distribuição da duração do tempo de execução do HLO para CoreType-CoreID com média, p50, p90, p95 e p999.
Tamanho da fila do otimizador de nível elevado A monitorização do tamanho da fila de execução do HLO acompanha o número de programas HLO compilados que estão a aguardar ou em execução. Esta métrica revela o congestionamento do pipeline de execução, o que permite identificar restrições de desempenho na execução de hardware, sobrecarga do controlador ou atribuição de recursos. hlo_queue_size ["tensorcore-0: 1", "tensorcore-1: 2"]

# Measures queue size for CoreType-CoreID.
Latência completa coletiva Esta métrica mede a latência coletiva ponto a ponto na DCN em microssegundos, desde o anfitrião que inicia a operação até todos os pares receberem o resultado. Inclui a redução de dados no anfitrião e o envio de resultados para a TPU. Os resultados são strings que detalham o tamanho da memória intermédia, o tipo e as latências média, p50, p90, p95 e p99,9. collective_e2e_latency ["8MB+-ALL_REDUCE, 1000, 2000, 3000, 4000, 5000", …]

# Transfer size-collective op, mean, p50, p90, p95, p999 of collective end to end latency.
Latência de ida e volta na camada de transporte Distribuição dos tempos de resposta (RTT) mínimos observados nas ligações TCP usadas pelo gRPC para tráfego de TPU de várias fatias. grpc_tcp_min_round_trip_times ['27.63, 29.03, 38.52, 41.63, 52.74']

# Representa a média, os percentis 50, 90, 95 e 99, 9 da distribuição em microssegundos (µs).
Tráfego transmitido na camada de transporte Distribuição cumulativa da taxa de transferência recente de ligações TCP usadas pelo gRPC para tráfego de TPU de vários fragmentos. grpc_tcp_delivery_rates ['11354.89, 10986.35, 18239.55, 25718.35, 44841.55']

# Representa a média, os percentis 50, 90, 95 e 99, 9 da distribuição em microssegundos (µs).

Ler dados de métricas

Para ler dados de métricas, especifique o nome da métrica quando chamar a função tpumonitoring.get_metric. Pode inserir verificações de métricas ad hoc no código de baixo desempenho para identificar se os problemas de desempenho têm origem no software ou no hardware.

O seguinte exemplo de código mostra como ler a métrica duty_cycle:

from libtpu.sdk import tpumonitoring

metric = tpumonitoring.get_metric("duty_cycle_pct")

metric.description()
"The metric provides a list of duty cycle percentages, one for each
accelerator (from accelerator_0 to accelerator_x). The duty cycle represents
the percentage of time an accelerator was actively processing during the
last sample period, indicating TPU utilization."

metric.data()
["0.00", "0.00", "0.00", "0.00"]

# accelerator_0-3

Use métricas para verificar a utilização da TPU

Os exemplos seguintes mostram como usar métricas da biblioteca de monitorização de TPUs para acompanhar a utilização de TPUs.

Monitorize o ciclo de serviço da TPU durante a preparação do JAX

Cenário: está a executar um script de preparação do JAX e quer monitorizar a métrica duty_cycle_pct da TPU ao longo do processo de preparação para confirmar que as TPUs estão a ser usadas de forma eficaz. Pode registar esta métrica periodicamente durante a preparação para monitorizar a utilização da TPU.

O seguinte exemplo de código mostra como monitorizar o ciclo de serviço da TPU durante o treino do JAX:

import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
import time

 # --- Your JAX model and training setup would go here ---
 #  --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
    return jnp.sum(x)

def loss_fn(params, x, y):
    preds = simple_model(x)
    return jnp.mean((preds - y)**2)

def train_step(params, x, y, optimizer):
    grads = jax.grad(loss_fn)(params, x, y)
    return optimizer.update(grads, params)

key = jax.random.PRNGKey(0)
params = jnp.array([1.0, 2.0]) # Example params
optimizer = ... # Your optimizer (for example, optax.adam)
data_x = jnp.ones((10, 10))
data_y = jnp.zeros((10,))

num_epochs = 10
log_interval_steps = 2  # Log duty cycle every 2 steps

for epoch in range(num_epochs):
    for step in range(5): # Example steps per epoch

        params = train_step(params, data_x, data_y, optimizer)

        if (step + 1) % log_interval_steps == 0:
            # --- Integrate TPU Monitoring Library here to get duty_cycle ---
            duty_cycle_metric = tpumonitoring.get_metric("duty_cycle_pct")
            duty_cycle_data = duty_cycle_metric.data
            print(f"Epoch {epoch+1}, Step {step+1}: TPU Duty Cycle Data:")
            print(f"  Description: {duty_cycle_metric.description}")
            print(f"  Data: {duty_cycle_data}")
            # --- End TPU Monitoring Library Integration ---

        # --- Rest of your training loop logic ---
        time.sleep(0.1) # Simulate some computation

print("Training complete.")

Verifique a utilização da HBM antes de executar a inferência JAX

Cenário: antes de executar a inferência com o seu modelo JAX, verifique a utilização atual da HBM (memória de largura de banda elevada) na TPU para confirmar que tem memória suficiente disponível e para obter uma medição de base antes de iniciar a inferência.

# The following code sample shows how to check HBM utilization before JAX inference:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring

  # --- Your JAX model and inference setup would go here ---
  # --- Example placeholder model (replace with your actual model loading/setup)---
def simple_model(x):
    return jnp.sum(x)

key = jax.random.PRNGKey(0)
params = ... # Load your trained parameters

  # Integrate the TPU Monitoring Library to get HBM utilization before inference
hbm_util_metric = tpumonitoring.get_metric("hbm_util")
hbm_util_data = hbm_util_metric.data
print("HBM Utilization Before Inference:")
print(f"  Description: {hbm_util_metric.description}")
print(f"  Data: {hbm_util_data}")
  # End TPU Monitoring Library Integration

  # Your Inference Logic
input_data = jnp.ones((1, 10)) # Example input
predictions = simple_model(input_data)
print("Inference Predictions:", predictions)

print("Inference complete.")

Verifique as métricas de rede

Cenário: está a executar uma carga de trabalho com vários anfitriões e várias fatias e quer estabelecer ligação a um dos pods do GKE ou às TPUs através de SSH para ver as métricas de rede enquanto a carga de trabalho está em execução. Os comandos também podem ser incorporados diretamente na carga de trabalho com vários anfitriões.

import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring

 # --- Your JAX model and training setup goes here ---
 #  --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
    return jnp.sum(x)
     # --- details here ---

    # ==============================================================================
    # Metric 1: TCP Delivery Rate
    # ==============================================================================
    # This metric reflects the delivery rate of the TCP connection (bytes delivered / elapsed time).
    # The output is a list of strings representing latency statistics:
    # [mean, p50, p90, p95, p99.9]
    # Example: ['100.00', '200.00', '300.00', '400.00', '500.00'] (Values in Mbps)

        # Get the metric object
        delivery_rate_metric = tpumonitoring.get_metric(metric_name_rate)

        # Print the description provided by the library
        print("Description:", delivery_rate_metric.description())

        # Print the actual data payload
        print("Data:", delivery_rate_metric.data())

    # ==============================================================================
    # Metric 2: TCP Minimum Round Trip Time (RTT)
    # ==============================================================================
    # This metric reflects the minimum RTT measured between sending a TCP packet
    # and receiving the acknowledgement.
    # The output is a list of strings representing latency statistics:
    # [mean, p50, p90, p95, p99.9]
    # Example: ['100.00', '200.00', '300.00', '400.00', '500.00'] (Values in us - microseconds)

        # Get the metric object
        min_rtt_metric = tpumonitoring.get_metric(metric_name_rtt)

        # Print the description provided by the library
        print("Description:", min_rtt_metric.description())

        # Print the actual data payload
        print("Data:", min_rtt_metric.data())

Frequência de atualização das métricas de TPU

A frequência de atualização das métricas da TPU está limitada a um mínimo de um segundo. Os dados de métricas do anfitrião são exportados a uma frequência fixa de 1 Hz. A latência introduzida por este processo de exportação é insignificante. As métricas de tempo de execução da LibTPU não estão sujeitas à mesma restrição de frequência. No entanto, para manter a consistência, estas métricas também são amostradas a 1 Hz ou 1 amostra por segundo.

Módulo TPU-Z

O TPU-Z é uma funcionalidade de telemetria e depuração para TPUs. Fornece informações detalhadas sobre o estado de tempo de execução para todos os núcleos da TPU anexados a um anfitrião. A funcionalidade é fornecida através do módulo tpuz, que faz parte do módulo libtpu.sdk no SDK Python libtpu. O módulo fornece uma vista geral do estado de cada núcleo.

O exemplo de utilização principal do TPU-Z é o diagnóstico de bloqueios ou impasses em cargas de trabalho de TPU distribuídas. Pode consultar o serviço TPU-Z em anfitriões para capturar o estado de cada núcleo, comparando os contadores de programas, as localizações HLO e os IDs de execução em todos os núcleos para identificar anomalias.

Use a função get_core_state_summary() na biblioteca libtpu.sdk para apresentar as métricas de TPU-Z:

summary = sdk.tpuz.get_core_state_summary()

O resultado das métricas de TPU-Z é fornecido como um dicionário. Segue-se um exemplo reduzido para um único núcleo:

{
  "host_name": "my-tpu-host-vm",
  "core_states": {
    "1": {
      "core_id": {
        "global_core_id": 1,
        "chip_id": 0,
        "core_on_chip": {
          "type": "TPU_CORE_TYPE_TENSOR_CORE",
          "index": 1
        }
      },
      "sequencer_info": [
        {
          "sequencer_type": "TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER",
          "sequencer_index": 0,
          "pc": 4490,
          "program_id": 3274167277388825310,
          "run_id": 3
        }
      ],
      "program_fingerprint": "b'\\xbefB\\xc6\\x1eb\\xc1#\\xd0...'",
      "queued_program_info": [],
      "error_message": ""
    }
    // ...
  }
}

Para obter informações sobre os otimizadores de nível elevado (HLO) em cada núcleo, defina o parâmetro include_hlo_info como True:

summary = sdk.tpuz.get_core_state_summary(include_hlo_info=True)

A saída inclui informações de HLO adicionais:

"1": {
    "core_id": {
      "global_core_id": 1,
      "chip_id": 0,
      "core_on_chip": {
        "type": "TPU_CORE_TYPE_TENSOR_CORE",
        "index": 1
      }
    },
    "sequencer_info": [
      {
        "sequencer_type": "TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER",
        "sequencer_index": 0,
        "pc": 17776,
        "tag": 3,
        "tracemark": 2147483646,
        "program_id": 3230481660274331500,
        "run_id": 81,
        "hlo_location": "HLO: fusion.11; HLO computation: main.126_spmd",
        "hlo_detailed_info": "[{\"details\":\"HloModule fusion.11, entry_computation_layout={(bf16[20>..."
      }
    ],
    "program_fingerprint": "\ufffdU\ufffd4j\u7c6e\ufffd\ufffd{\u0017\ufffd\ufffdHHV\ufffdD\ufffde\uff>",
    "launch_id": 1394130914,
    "queued_program_info": [
      {
        "run_id": 81,
        "launch_id": 1394130914,
        "program_fingerprint": "\ufffdU\ufffd4j\u7c6e\ufffd\ufffd{\u0017\ufffd\ufffdHHV\ufffdD\ufffde>"
      }
    ]
  }

Métricas do TPU-Z

A função get_core_state_summary devolve métricas de TPU-Z sob a forma de um dicionário com a seguinte estrutura.

CurrentCoreStateSummary

O dicionário CurrentCoreStateSummary fornece um resumo detalhado do estado de um núcleo de TPU individual.

Campo Tipo Descrição
core_id dicionário Um TpuCoreIdentifierdicionário que contém informações de ID sobre o núcleo da TPU.
sequencer_info lista de dicionários Uma lista de dicionários SequencerInfo, que descrevem o estado de cada sequenciador no núcleo.
program_fingerprint bytes A impressão digital do programa que está a ser executado neste núcleo.
launch_id número inteiro O ID de lançamento do programa atual ou mais recente.
queued_program_info lista de dicionários Uma lista de QueuedProgramInfo dicionários para programas colocados em fila para execução.
error_message de string Quaisquer mensagens de erro para este núcleo.

TpuCoreIdentifier

O dicionário TpuCoreIdentifier fornece informações de ID para núcleos no sistema de TPU.

Campo Tipo Descrição
global_core_id número inteiro O ID do núcleo.
chip_id número inteiro O ID do chip ao qual o núcleo pertence.
core_on_chip dicionário Um TpuCoreOnChip dicionário que descreve o tipo do núcleo e o respetivo índice no chip.

TpuCoreOnChip

O dicionário TpuCoreOnChip contém informações sobre as propriedades de um núcleo num chip específico.

Campo Tipo Descrição
type de string O tipo de núcleo da TPU. Por exemplo: TPU_CORE_TYPE_TENSOR_CORE.
index número inteiro O índice do núcleo no chip.

SequencerInfo

O dicionário SequencerInfo contém informações sobre o estado de um único sequenciador num núcleo.

Campo Tipo Descrição
sequencer_type de string O tipo de sequenciador. Por exemplo: TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER.
sequencer_index número inteiro O índice do sequenciador (se existirem vários do mesmo tipo).
pc número inteiro O valor atual do contador do programa.
program_id número inteiro O ID associado a uma instância específica de um programa que está a ser iniciado para execução num núcleo da TPU.
run_id número inteiro O ID de execução associado a uma instância específica da execução de um programa num núcleo da TPU.
hlo_location de string Informações de localização do otimizador de nível superior.
hlo_detailed_info de string Informações detalhadas do otimizador de nível superior.

QueuedProgramInfo

O dicionário QueuedProgramInfo contém informações sobre programas em fila de espera para execução num núcleo.

Campo Tipo Descrição
run_id número inteiro O ID de execução do programa em fila.
launch_id número inteiro O ID de lançamento do programa em fila.
program_fingerprint bytes A impressão digital do programa em fila.

TPU-Z com JAX

Pode aceder às métricas do TPU-Z em cargas de trabalho JAX através da biblioteca libtpu.sdk. O seguinte script Python usa o JAX para a computação de tensores de elevado desempenho, enquanto usa simultaneamente o SDK libtpu num processo em segundo plano para monitorizar o estado e a atividade do hardware da TPU subjacente.

Inclua os seguintes pacotes Python:

import jax
import jax.numpy as jnp
import time
import threading
from functools import partial
from libtpu import sdk

A função monitor_tpu_status usa um tópico em segundo plano para mostrar continuamente o estado operacional dos núcleos das TPUs enquanto a aplicação principal executa uma carga de trabalho JAX. Funciona como uma ferramenta de diagnóstico em tempo real.

def monitor_tpu_status():
  """Monitors TPU status in a background thread."""

  while monitoring_active:
    try:
      summary = sdk.tpuz.get_core_state_summary(include_hlo_info=True)
      if summary and 'core_states' in summary:
        print(summary)
      else:
        print('WARNING: Call returned an empty or invalid summary.')
    except RuntimeError as e:
      print(f'FAIL: Error calling API: {e}')
    except Exception as e:
      print(f'FAIL: Unexpected error in monitor thread: {e}')

    for _ in range(MONITORING_INTERVAL_SECONDS * 2):
      if not monitoring_active:
        break
      time.sleep(0.5)
  print('✅ Monitoring thread stopped.')

A função transformer_block implementa uma camada completa da arquitetura Transformer, que é a base fundamental para os MDIs.

@partial(jax.jit, static_argnames=['num_heads'])
def transformer_block(params, x, num_heads=32):
  """A simplified but computationally intensive Transformer block."""
  # Multi-head Self-Attention
  qkv = jnp.dot(x, params['qkv_kernel'])
  q, k, v = jnp.array_split(qkv, 3, axis=-1)

  # Reshape for multi-head attention
  q = q.reshape(q.shape[0], q.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
  k = k.reshape(k.shape[0], k.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
  v = v.reshape(v.shape[0], v.shape[1], num_heads, -1).transpose(0, 2, 1, 3)

  # Scaled dot-product attention
  attention_scores = jnp.einsum('nhqd,nhkd->nhqk', q, k) / jnp.sqrt(q.shape[-1])
  attention_weights = jax.nn.softmax(attention_scores, axis=-1)
  attention_output = jnp.einsum('nhqk,nhvd->nhqd', attention_weights, v)
  attention_output = attention_output.transpose(0, 2, 1, 3).reshape(x.shape)

  attention_output = jnp.dot(attention_output, params['o_kernel'])

  # Residual connection and Layer Normalization 1
  h1 = x + attention_output
  h1_norm = h1 - jnp.mean(h1, axis=-1, keepdims=True)
  h1_norm = h1_norm / jnp.sqrt(
      jnp.mean(jnp.square(h1_norm), axis=-1, keepdims=True) + 1e-5
  )

  # Feed-Forward Network
  ffn_hidden = jax.nn.gelu(jnp.dot(h1_norm, params['ffn1_kernel']))
  ffn_output = jnp.dot(ffn_hidden, params['ffn2_kernel'])

  # Residual connection and Layer Normalization 2
  h2 = h1_norm + ffn_output
  h2_norm = h2 - jnp.mean(h2, axis=-1, keepdims=True)
  h2_norm = h2_norm / jnp.sqrt(
      jnp.mean(jnp.square(h2_norm), axis=-1, keepdims=True) + 1e-5
  )

  return h2_norm

A função main orquestra a configuração da computação JAX, inicia a monitorização da TPU em segundo plano e executa o ciclo de carga de trabalho principal.

def main():
  num_devices = jax.device_count()
  print(f"Running on {num_devices} devices.")

  batch_size = 128 * num_devices
  seq_len = 512
  embed_dim = 1024
  ffn_dim = embed_dim * 4

  key = jax.random.PRNGKey(0)

  params = {
      'qkv_kernel': jax.random.normal(
          key, (embed_dim, embed_dim * 3), dtype=jnp.bfloat16
      ),
      'o_kernel': jax.random.normal(
          key, (embed_dim, embed_dim), dtype=jnp.bfloat16
      ),
      'ffn1_kernel': jax.random.normal(
          key, (embed_dim, ffn_dim), dtype=jnp.bfloat16
      ),
      'ffn2_kernel': jax.random.normal(
          key, (ffn_dim, embed_dim), dtype=jnp.bfloat16
      ),
  }
  input_data = jax.random.normal(
      key, (batch_size, seq_len, embed_dim), dtype=jnp.bfloat16
  )
  input_data = jax.device_put(input_data)
  monitor_thread = threading.Thread(target=monitor_tpu_status)
  monitor_thread.start()
  print("Starting JAX computation loop...")
  start_time = time.time()
  iterations = 0
  while time.time() - start_time < JOB_DURATION_SECONDS:
    result = transformer_block(params, input_data)
    result.block_until_ready()
    iterations += 1
    print(f'  -> Jax iteration {iterations} complete.', end='\r')

  print(f"\nCompleted {iterations} iterations in {JOB_DURATION_SECONDS} seconds.")

  global monitoring_active
  monitoring_active = False
  monitor_thread.join()

if __name__ == '__main__':
  main()

Resolução de problemas

Esta secção fornece informações de resolução de problemas para ajudar a identificar e resolver problemas que possa encontrar ao usar a biblioteca de monitorização de TPUs.

Funcionalidades ou métricas em falta

Se não conseguir ver algumas funcionalidades ou métricas, a causa mais comum é uma versão desatualizada do libtpu. As funcionalidades e as métricas da biblioteca de monitorização de TPUs estão incluídas nas versões do libtpu, e as versões desatualizadas podem não ter novas funcionalidades e métricas.

Verifique a versão do libtpu que está a ser executada no seu ambiente:

Linha de comandos:

pip show libtpu

Python:

import libtpu

print(libtpu.__version__)

Se não estiver a usar a versão mais recente do libtpu, use o seguinte comando para atualizar a biblioteca:

pip install --upgrade libtpu