Biblioteca de monitoramento de TPUs

Tenha acesso a insights detalhados sobre o desempenho e o comportamento do hardware do Cloud TPU com recursos avançados de monitoramento, criados diretamente na camada de software de base, a LibTPU. Embora a LibTPU inclua drivers, bibliotecas de rede, o compilador XLA e o ambiente de execução para interagir com TPUs, o foco deste documento é a biblioteca de monitoramento de TPUs.

A biblioteca de monitoramento de TPUs oferece:

  • Observabilidade abrangente: tenha acesso à API de telemetria e ao conjunto de métricas. Com isso, você tem acesso a insights detalhados sobre o desempenho operacional e comportamentos específicos das TPUs.

  • Kits de ferramentas de diagnóstico: fornecem um SDK e uma interface de linha de comando (CLI) projetados para permitir a depuração e a análise detalhada de desempenho dos recursos de TPU.

Esses recursos de monitoramento foram projetados para oferecer uma solução de alto nível voltada ao cliente, com as ferramentas essenciais para otimizar cargas de trabalho de TPU de maneira eficaz.

A biblioteca de monitoramento de TPUs oferece informações detalhadas sobre o desempenho das cargas de trabalho de machine learning no hardware de TPU. Ela foi projetada para ajudar você a entender o uso da TPU, identificar gargalos e depurar problemas de desempenho. Além disso, ela oferece informações mais detalhadas do que as métricas de interrupção, de goodput e outras.

Introdução à biblioteca de monitoramento de TPUs

É fácil acessar esses insights valiosos. Como a funcionalidade de monitoramento de TPU é integrada ao SDK LibTPU, ela é incluída quando você instala a LibTPU.

Instalar a LibTPU

pip install libtpu

As atualizações da LibTPU são coordenadas com os lançamentos do JAX. Isso significa que, ao instalar a versão mais recente do JAX (lançada mensalmente), você geralmente tem acesso à versão mais recente disponível da LibTPU e aos recursos dela.

Instalar o JAX

pip install -U "jax[tpu]"

Para usuários do PyTorch, a instalação do PyTorch/XLA oferece a funcionalidade mais recente da LibTPU e do monitoramento de TPUs.

Instalar o PyTorch/XLA

pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' \
  -f https://storage.googleapis.com/libtpu-releases/index.html \
  -f https://storage.googleapis.com/libtpu-wheels/index.html

  # Optional: if you're using custom kernels, install pallas dependencies
pip install 'torch_xla[pallas]' \
  -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
  -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

Para saber como instalar o PyTorch/XLA, consulte Instalação no repositório do GitHub do PyTorch/XLA.

Importar a biblioteca em Python

Para começar a usar a biblioteca de monitoramento de TPUs, importe o módulo libtpu no código Python.

from libtpu.sdk import tpumonitoring

Listar todas as funcionalidades disponíveis

Liste todos os nomes de métricas e a funcionalidade disponível:


from libtpu.sdk import tpumonitoring

tpumonitoring.help()
" libtpu.sdk.monitoring.help():
      List all supported functionality.

  libtpu.sdk.monitoring.list_support_metrics()
      List support metric names in the list of str format.

  libtpu.sdk.monitoring.get_metric(metric_name:str)
      Get metric data with metric name. It represents the snapshot mode.
      The metric data is a object with `description()` and `data()` methods,
      where the `description()` returns a string describe the format of data
      and data unit, `data()` returns the metric data in the list in str format.
"

Métricas disponíveis

O exemplo de código abaixo mostra como listar todos os nomes de métricas aceitos:

from libtpu.sdk import tpumonitoring

tpumonitoring.list_supported_metrics()

["duty_cycle_pct", "tensorcore_util", "hbm_util", ...]

A tabela abaixo mostra todas as métricas e as definições correspondentes:

Métrica Definição Nome da métrica para a API Exemplos de valores
Utilização do TensorCore Mede a porcentagem de uso do TensorCore, calculada como a porcentagem de operações que fazem parte das operações do TensorCore. Amostras de 10 microssegundos são coletadas a cada 1 segundo. Não é possível modificar a taxa de amostragem. Com essa métrica, é possível monitorar a eficiência das cargas de trabalho nos dispositivos de TPU. tensorcore_util ['1.11', '2.22', '3.33', '4.44']

# Porcentagem de utilização para o ID de acelerador 0-3
Porcentagem do ciclo de trabalho Porcentagem de tempo durante o período de amostra anterior (a cada cinco segundos; pode ser ajustada definindo a flag LIBTPU_INIT_ARG) em que o acelerador estava realizando processamento ativo (registrado com ciclos usados para executar programas HLO durante o último período de amostragem). Essa métrica representa o nível de atividade de uma TPU e é emitida por chip. duty_cycle_pct ['10.00', '20.00', '30.00', '40.00']

# Porcentagem do ciclo de trabalho para o ID de acelerador 0-3
Capacidade total de HBM Essa métrica informa a capacidade total de HBM em bytes. hbm_capacity_total ['30000000000', '30000000000', '30000000000', '30000000000']

# Capacidade total de HBM em bytes associada ao ID de acelerador 0-3
Uso da capacidade de HBM Essa métrica informa o uso da capacidade de HBM em bytes no último período de amostragem (a cada cinco segundos; pode ser ajustada definindo a flag LIBTPU_INIT_ARG). hbm_capacity_usage ['100', '200', '300', '400']

# Utilização da capacidade de HBM em bytes associada ao ID de acelerador 0-3
Latência de transferência de buffer Latências de transferência de rede para tráfego de várias frações em escala massiva. Essa visualização permite entender o ambiente geral de desempenho da rede. buffer_transfer_latency ["'8MB+', '2233.25', '2182.02', '3761.93', '19277.01', '53553.6'"]

# tamanho do buffer, média, p50, p90, p99, p99.9 da distribuição da latência de transferência de rede
Métricas de distribuição do tempo de execução geral de operações Fornece insights de desempenho granulares sobre o status de execução do binário compilado de HLO para ajudar na detecção de regressão e na depuração em nível de modelo. hlo_exec_timing ["'tensorcore-0', '10.00', '10.00', '20.00', '30.00', '40.00'"]

# Distribuição da duração do tempo de execução de HLO para CoreType-CoreID com média, p50, p90, p95 e p999
Tamanho geral da fila do otimizador O monitoramento do tamanho da fila de execução de HLO acompanha o número de programas HLO compilados que estão em execução ou aguardando. Essa métrica revela o congestionamento do pipeline de execução, o que ajuda na identificação de gargalos de desempenho na execução do hardware, no overhead do driver ou na alocação de recursos. hlo_queue_size ["tensorcore-0: 1", "tensorcore-1: 2"]

# Mede o tamanho da fila para CoreType-CoreID.
Latência coletiva de ponta a ponta Essa métrica mede a latência coletiva de ponta a ponta na DCN em microssegundos, desde o início da operação pelo host até o recebimento da saída por todos os peers. Isso inclui a redução de dados do host e o envio da saída para a TPU. Os resultados são strings que detalham o tamanho, o tipo e as latências média, p50, p90, p95 e p99,9 do buffer. collective_e2e_latency ["8MB+-ALL_REDUCE, 1000, 2000, 3000, 4000, 5000", …]

# Tamanho da transferência - operação de comunicação coletiva, média, p50, p90, p95, p999 da latência de ponta a ponta da operação de comunicação coletiva

Ler dados de métricas: modo de snapshot

Para ativar o modo de snapshot, especifique o nome da métrica ao chamar a função tpumonitoring.get_metric. O modo de snapshot permite inserir verificações de métricas sob demanda em códigos de baixo desempenho para identificar se os problemas de desempenho são causados pelo software ou pelo hardware.

O exemplo de código a seguir mostra como usar o modo de snapshot para ler o duty_cycle.

from libtpu.sdk import tpumonitoring

metric = tpumonitoring.get_metric("duty_cycle_pct")

metric.description()
"The metric provides a list of duty cycle percentages, one for each
accelerator (from accelerator_0 to accelerator_x). The duty cycle represents
the percentage of time an accelerator was actively processing during the
last sample period, indicating TPU utilization."

metric.data()
["0.00", "0.00", "0.00", "0.00"]

# accelerator_0-3

Acessar métricas usando a CLI

As etapas abaixo mostram como interagir com as métricas da LibTPU usando a CLI:

  1. Instale tpu-info:

    pip install tpu-info
    
    
    # Access help information of tpu-info
    tpu-info --help / -h
    
    
  2. Execute a versão padrão de tpu-info:

    tpu-info
    

    O resultado será assim:

   TPU Chips
   ┏━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━┓
    Chip         Type         Devices  PID       ┡━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━┩
    /dev/accel0  TPU v4 chip  1        130007     /dev/accel1  TPU v4 chip  1        130007     /dev/accel2  TPU v4 chip  1        130007     /dev/accel3  TPU v4 chip  1        130007    └─────────────┴─────────────┴─────────┴────────┘

   TPU Runtime Utilization
   ┏━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
    Device  Memory usage          Duty cycle    ┡━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
    0       0.00 GiB / 31.75 GiB       0.00%     1       0.00 GiB / 31.75 GiB       0.00%     2       0.00 GiB / 31.75 GiB       0.00%     3       0.00 GiB / 31.75 GiB       0.00%    └────────┴──────────────────────┴────────────┘

   TensorCore Utilization
   ┏━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓
    Chip ID  TensorCore Utilization    ┡━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩
    0                         0.00%     1                         0.00%     3                         0.00%     2                         0.00% |
   └─────────┴────────────────────────┘

   Buffer Transfer Latency
   ┏━━━━━━━━━━━━━┳━━━━━┳━━━━━┳━━━━━┳━━━━━━┓
    Buffer Size  P50  P90  P95  P999    ┡━━━━━━━━━━━━━╇━━━━━╇━━━━━╇━━━━━╇━━━━━━┩
          8MB+  | 0us  0us  0us   0us |
   └─────────────┴─────┴─────┴─────┴──────┘

Usar métricas para verificar a utilização de TPUs

Os exemplos a seguir mostram como usar métricas da biblioteca de monitoramento de TPUs para acompanhar a utilização das TPUs.

Monitorar o ciclo de trabalho da TPU durante o treinamento do JAX

Cenário: você está executando um script de treinamento do JAX e quer monitorar a métrica duty_cycle_pct das TPUs durante todo o processo de treinamento para confirmar se elas estão sendo usadas de maneira eficaz. É possível registrar essa métrica periodicamente durante o treinamento para acompanhar a utilização das TPUs.

O exemplo de código abaixo mostra como monitorar o ciclo de trabalho das TPUs durante o treinamento do JAX:

import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
import time

 # --- Your JAX model and training setup would go here ---
 #  --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
    return jnp.sum(x)

def loss_fn(params, x, y):
    preds = simple_model(x)
    return jnp.mean((preds - y)**2)

def train_step(params, x, y, optimizer):
    grads = jax.grad(loss_fn)(params, x, y)
    return optimizer.update(grads, params)

key = jax.random.PRNGKey(0)
params = jnp.array([1.0, 2.0]) # Example params
optimizer = ... # Your optimizer (for example, optax.adam)
data_x = jnp.ones((10, 10))
data_y = jnp.zeros((10,))

num_epochs = 10
log_interval_steps = 2  # Log duty cycle every 2 steps

for epoch in range(num_epochs):
    for step in range(5): # Example steps per epoch

        params = train_step(params, data_x, data_y, optimizer)

        if (step + 1) % log_interval_steps == 0:
            # --- Integrate TPU Monitoring Library here to get duty_cycle ---
            duty_cycle_metric = tpumonitoring.get_metric("duty_cycle_pct")
            duty_cycle_data = duty_cycle_metric.data
            print(f"Epoch {epoch+1}, Step {step+1}: TPU Duty Cycle Data:")
            print(f"  Description: {duty_cycle_metric.description}")
            print(f"  Data: {duty_cycle_data}")
            # --- End TPU Monitoring Library Integration ---

        # --- Rest of your training loop logic ---
        time.sleep(0.1) # Simulate some computation

print("Training complete.")

Verificar a utilização de HBM antes de executar a inferência do JAX

Cenário: antes de executar a inferência com o modelo do JAX, verifique a utilização atual de HBM (memória de alta largura de banda) na TPU para confirmar se você tem memória suficiente disponível e para receber uma medição de base antes do início da inferência.

# The following code sample shows how to check HBM utilization before JAX inference:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring

  # --- Your JAX model and inference setup would go here ---
  # --- Example placeholder model (replace with your actual model loading/setup)---
def simple_model(x):
    return jnp.sum(x)

key = jax.random.PRNGKey(0)
params = ... # Load your trained parameters

  # Integrate the TPU Monitoring Library to get HBM utilization before inference
hbm_util_metric = tpumonitoring.get_metric("hbm_util")
hbm_util_data = hbm_util_metric.data
print("HBM Utilization Before Inference:")
print(f"  Description: {hbm_util_metric.description}")
print(f"  Data: {hbm_util_data}")
  # End TPU Monitoring Library Integration

  # Your Inference Logic
input_data = jnp.ones((1, 10)) # Example input
predictions = simple_model(input_data)
print("Inference Predictions:", predictions)

print("Inference complete.")

Frequência de exportação das métricas de TPU

A frequência de atualização das métricas de TPU é limitada a um mínimo de um segundo. Os dados de métricas do host são exportados em uma frequência fixa de 1 Hz. A latência introduzida por esse processo de exportação é insignificante. As métricas de ambiente de execução da LibTPU não estão sujeitas à mesma restrição de frequência. No entanto, para fins de consistência, essas métricas também são amostradas a 1 Hz ou geram uma amostra por segundo.