Biblioteca de monitorização de TPUs
Desbloqueie estatísticas detalhadas sobre o desempenho e o comportamento do hardware da Cloud TPU com capacidades de monitorização avançadas da TPU, criadas diretamente na camada de software fundamental, a LibTPU. Embora a LibTPU abranja controladores, bibliotecas de rede, o compilador XLA e o tempo de execução da TPU para interagir com as TPUs, o foco deste documento é a biblioteca de monitorização da TPU.
A biblioteca de monitorização de TPU oferece:
Observabilidade abrangente: aceda à API de telemetria e ao conjunto de métricas, que fornecem estatísticas detalhadas sobre o desempenho operacional e os comportamentos específicos das suas TPUs.
Kits de ferramentas de diagnóstico: fornece um SDK e uma interface de linhas de comando (CLI) concebidos para permitir a depuração e a análise detalhada do desempenho dos seus recursos de TPU.
Estas funcionalidades de monitorização foram concebidas para serem uma solução de nível superior orientada para o cliente, que lhe oferece as ferramentas essenciais para otimizar os seus encargos de trabalho de TPU de forma eficaz.
A biblioteca de monitorização da TPU fornece informações detalhadas sobre o desempenho das cargas de trabalho de aprendizagem automática no hardware da TPU. Foi concebida para ajudar a compreender a utilização da TPU, identificar gargalos e depurar problemas de desempenho. Dá-lhe informações mais detalhadas do que as métricas de interrupção, as métricas de débito útil e outras métricas.
Comece a usar a biblioteca de monitorização de TPUs
Aceder a estas estatísticas avançadas é simples. A funcionalidade de monitorização da TPU está integrada no SDK LibTPU, pelo que a funcionalidade está incluída quando instala o LibTPU.
Instale a LibTPU
pip install libtpu
Em alternativa, as atualizações da LibTPU são coordenadas com os lançamentos do JAX, o que significa que, quando instala o lançamento mais recente do JAX (lançado mensalmente), este fixa normalmente a versão mais recente da LibTPU compatível e as respetivas funcionalidades.
Instale o JAX
pip install -U "jax[tpu]"
Para os utilizadores do PyTorch, a instalação do PyTorch/XLA fornece a funcionalidade de monitorização do LibTPU e TPU mais recente.
Instale o PyTorch/XLA
pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' \
-f https://storage.googleapis.com/libtpu-releases/index.html \
-f https://storage.googleapis.com/libtpu-wheels/index.html
# Optional: if you're using custom kernels, install pallas dependencies
pip install 'torch_xla[pallas]' \
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
Para mais informações sobre a instalação do PyTorch/XLA, consulte a secção Instalação no repositório do GitHub do PyTorch/XLA.
Importe a biblioteca no Python
Para começar a usar a biblioteca de monitorização de TPUs, tem de importar o módulo libtpu
no seu código Python.
from libtpu.sdk import tpumonitoring
Apresentar todas as funcionalidades suportadas
Indique todos os nomes das métricas e a funcionalidade que suportam:
from libtpu.sdk import tpumonitoring
tpumonitoring.help()
" libtpu.sdk.monitoring.help():
List all supported functionality.
libtpu.sdk.monitoring.list_support_metrics()
List support metric names in the list of str format.
libtpu.sdk.monitoring.get_metric(metric_name:str)
Get metric data with metric name. It represents the snapshot mode.
The metric data is a object with `description()` and `data()` methods,
where the `description()` returns a string describe the format of data
and data unit, `data()` returns the metric data in the list in str format.
"
Métricas compatíveis
O seguinte exemplo de código mostra como listar todos os nomes de métricas suportados:
from libtpu.sdk import tpumonitoring
tpumonitoring.list_supported_metrics()
["duty_cycle_pct", "tensorcore_util", "hbm_util", ...]
A tabela seguinte mostra todas as métricas e as respetivas definições:
| Métrica | Definição | Nome da métrica para a API | Valores de exemplo |
|---|---|---|---|
| Utilização do núcleo Tensor | Mede a percentagem da utilização dos TensorCores, calculada como a percentagem de operações que fazem parte das operações dos TensorCores. Amostragem de 10 microssegundos a cada 1 segundo. Não pode modificar a taxa de amostragem. Esta métrica permite-lhe monitorizar a eficiência das suas cargas de trabalho em dispositivos TPU. |
tensorcore_util
|
['1.11', '2.22', '3.33', '4.44']
# utilization percentage for accelerator ID 0-3. |
| Percentagem do ciclo de atividade | Percentagem do tempo durante o período de amostragem anterior (a cada 5 segundos; pode ser ajustada através da definição da flag LIBTPU_INIT_ARG) durante o qual o acelerador estava a processar ativamente (registado com os ciclos usados para executar programas HLO durante o último período de amostragem). Esta métrica representa o nível de ocupação de uma TPU. A métrica é emitida por chip.
|
duty_cycle_pct
|
['10.00', '20.00', '30.00', '40.00']
# Percentagem do ciclo de trabalho para o ID do acelerador 0-3. |
| HBM Capacity Total | Esta métrica indica a capacidade total de HBM em bytes. |
hbm_capacity_total
|
['30000000000', '30000000000', '30000000000', '30000000000']
# Capacidade total de HBM em bytes associada ao ID do acelerador 0-3. |
| Utilização da capacidade de HBM | Esta métrica comunica a utilização da capacidade de HBM em bytes durante o período de amostragem anterior (a cada 5 segundos; pode ser ajustada através da definição da flag LIBTPU_INIT_ARG).
|
hbm_capacity_usage
|
['100', '200', '300', '400']
# Utilização da capacidade para HBM em bytes anexados ao ID do acelerador 0-3. |
| Latência de transferência da memória intermédia | Latências de transferência de rede para tráfego multislice de grande escala. Esta visualização permite-lhe compreender o ambiente de desempenho geral da rede. |
buffer_transfer_latency
|
["'8MB+', '2233.25', '2182.02', '3761.93', '19277.01', '53553.6'"]
# buffer size, mean, p50, p90, p99, p99.9 of network transfer latency distribution. |
| Métricas de distribuição do tempo de execução de operações de nível elevado | Fornece estatísticas de desempenho detalhadas sobre o estado de execução do binário compilado HLO, o que permite a deteção de regressão e a depuração ao nível do modelo. |
hlo_exec_timing
|
["'tensorcore-0', '10.00', '10.00', '20.00', '30.00', '40.00'"]
# A distribuição da duração do tempo de execução do HLO para CoreType-CoreID com média, p50, p90, p95 e p999. |
| Tamanho da fila do otimizador de nível elevado | A monitorização do tamanho da fila de execução do HLO acompanha o número de programas HLO compilados que estão a aguardar ou em execução. Esta métrica revela o congestionamento do pipeline de execução, o que permite identificar restrições de desempenho na execução de hardware, sobrecarga do controlador ou atribuição de recursos. |
hlo_queue_size
|
["tensorcore-0: 1", "tensorcore-1: 2"]
# Measures queue size for CoreType-CoreID. |
| Latência completa coletiva | Esta métrica mede a latência coletiva ponto a ponto na DCN em microssegundos, desde o anfitrião que inicia a operação até todos os pares receberem o resultado. Inclui a redução de dados no anfitrião e o envio de resultados para a TPU. Os resultados são strings que detalham o tamanho da memória intermédia, o tipo e as latências média, p50, p90, p95 e p99,9. |
collective_e2e_latency
|
["8MB+-ALL_REDUCE, 1000, 2000, 3000, 4000, 5000", …]
# Transfer size-collective op, mean, p50, p90, p95, p999 of collective end to end latency. |
| Latência de ida e volta na camada de transporte | Distribuição dos tempos de resposta (RTT) mínimos observados nas ligações TCP usadas pelo gRPC para tráfego de TPU de várias fatias. |
grpc_tcp_min_round_trip_times
|
['27.63, 29.03, 38.52, 41.63, 52.74']
# Representa a média, os percentis 50, 90, 95 e 99, 9 da distribuição em microssegundos (µs). |
| Tráfego transmitido na camada de transporte | Distribuição cumulativa da taxa de transferência recente de ligações TCP usadas pelo gRPC para tráfego de TPU de vários fragmentos. |
grpc_tcp_delivery_rates
|
['11354.89, 10986.35, 18239.55, 25718.35, 44841.55']
# Representa a média, os percentis 50, 90, 95 e 99, 9 da distribuição em microssegundos (µs). |
Ler dados de métricas
Para ler dados de métricas, especifique o nome da métrica quando chamar a função tpumonitoring.get_metric. Pode inserir verificações de métricas ad hoc no código de baixo desempenho para identificar se os problemas de desempenho têm origem no software ou no hardware.
O seguinte exemplo de código mostra como ler a métrica duty_cycle:
from libtpu.sdk import tpumonitoring
metric = tpumonitoring.get_metric("duty_cycle_pct")
metric.description()
"The metric provides a list of duty cycle percentages, one for each
accelerator (from accelerator_0 to accelerator_x). The duty cycle represents
the percentage of time an accelerator was actively processing during the
last sample period, indicating TPU utilization."
metric.data()
["0.00", "0.00", "0.00", "0.00"]
# accelerator_0-3
Use métricas para verificar a utilização da TPU
Os exemplos seguintes mostram como usar métricas da biblioteca de monitorização de TPUs para acompanhar a utilização de TPUs.
Monitorize o ciclo de serviço da TPU durante a preparação do JAX
Cenário: está a executar um script de preparação do JAX e quer monitorizar a métrica duty_cycle_pct da TPU ao longo do processo de preparação para confirmar que as TPUs estão a ser usadas de forma eficaz. Pode registar esta métrica periodicamente durante a preparação para monitorizar a utilização da TPU.
O seguinte exemplo de código mostra como monitorizar o ciclo de serviço da TPU durante o treino do JAX:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
import time
# --- Your JAX model and training setup would go here ---
# --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
return jnp.sum(x)
def loss_fn(params, x, y):
preds = simple_model(x)
return jnp.mean((preds - y)**2)
def train_step(params, x, y, optimizer):
grads = jax.grad(loss_fn)(params, x, y)
return optimizer.update(grads, params)
key = jax.random.PRNGKey(0)
params = jnp.array([1.0, 2.0]) # Example params
optimizer = ... # Your optimizer (for example, optax.adam)
data_x = jnp.ones((10, 10))
data_y = jnp.zeros((10,))
num_epochs = 10
log_interval_steps = 2 # Log duty cycle every 2 steps
for epoch in range(num_epochs):
for step in range(5): # Example steps per epoch
params = train_step(params, data_x, data_y, optimizer)
if (step + 1) % log_interval_steps == 0:
# --- Integrate TPU Monitoring Library here to get duty_cycle ---
duty_cycle_metric = tpumonitoring.get_metric("duty_cycle_pct")
duty_cycle_data = duty_cycle_metric.data
print(f"Epoch {epoch+1}, Step {step+1}: TPU Duty Cycle Data:")
print(f" Description: {duty_cycle_metric.description}")
print(f" Data: {duty_cycle_data}")
# --- End TPU Monitoring Library Integration ---
# --- Rest of your training loop logic ---
time.sleep(0.1) # Simulate some computation
print("Training complete.")
Verifique a utilização da HBM antes de executar a inferência JAX
Cenário: antes de executar a inferência com o seu modelo JAX, verifique a utilização atual da HBM (memória de largura de banda elevada) na TPU para confirmar que tem memória suficiente disponível e para obter uma medição de base antes de iniciar a inferência.
# The following code sample shows how to check HBM utilization before JAX inference:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
# --- Your JAX model and inference setup would go here ---
# --- Example placeholder model (replace with your actual model loading/setup)---
def simple_model(x):
return jnp.sum(x)
key = jax.random.PRNGKey(0)
params = ... # Load your trained parameters
# Integrate the TPU Monitoring Library to get HBM utilization before inference
hbm_util_metric = tpumonitoring.get_metric("hbm_util")
hbm_util_data = hbm_util_metric.data
print("HBM Utilization Before Inference:")
print(f" Description: {hbm_util_metric.description}")
print(f" Data: {hbm_util_data}")
# End TPU Monitoring Library Integration
# Your Inference Logic
input_data = jnp.ones((1, 10)) # Example input
predictions = simple_model(input_data)
print("Inference Predictions:", predictions)
print("Inference complete.")
Verifique as métricas de rede
Cenário: está a executar uma carga de trabalho com vários anfitriões e várias fatias e quer estabelecer ligação a um dos pods do GKE ou às TPUs através de SSH para ver as métricas de rede enquanto a carga de trabalho está em execução. Os comandos também podem ser incorporados diretamente na carga de trabalho com vários anfitriões.
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
# --- Your JAX model and training setup goes here ---
# --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
return jnp.sum(x)
# --- details here ---
# ==============================================================================
# Metric 1: TCP Delivery Rate
# ==============================================================================
# This metric reflects the delivery rate of the TCP connection (bytes delivered / elapsed time).
# The output is a list of strings representing latency statistics:
# [mean, p50, p90, p95, p99.9]
# Example: ['100.00', '200.00', '300.00', '400.00', '500.00'] (Values in Mbps)
# Get the metric object
delivery_rate_metric = tpumonitoring.get_metric(metric_name_rate)
# Print the description provided by the library
print("Description:", delivery_rate_metric.description())
# Print the actual data payload
print("Data:", delivery_rate_metric.data())
# ==============================================================================
# Metric 2: TCP Minimum Round Trip Time (RTT)
# ==============================================================================
# This metric reflects the minimum RTT measured between sending a TCP packet
# and receiving the acknowledgement.
# The output is a list of strings representing latency statistics:
# [mean, p50, p90, p95, p99.9]
# Example: ['100.00', '200.00', '300.00', '400.00', '500.00'] (Values in us - microseconds)
# Get the metric object
min_rtt_metric = tpumonitoring.get_metric(metric_name_rtt)
# Print the description provided by the library
print("Description:", min_rtt_metric.description())
# Print the actual data payload
print("Data:", min_rtt_metric.data())
Frequência de atualização das métricas de TPU
A frequência de atualização das métricas da TPU está limitada a um mínimo de um segundo. Os dados de métricas do anfitrião são exportados a uma frequência fixa de 1 Hz. A latência introduzida por este processo de exportação é insignificante. As métricas de tempo de execução da LibTPU não estão sujeitas à mesma restrição de frequência. No entanto, para manter a consistência, estas métricas também são amostradas a 1 Hz ou 1 amostra por segundo.
Módulo TPU-Z
O TPU-Z é uma funcionalidade de telemetria e depuração para TPUs. Fornece informações detalhadas sobre o estado de tempo de execução para todos os núcleos da TPU anexados a um anfitrião. A funcionalidade é fornecida através do módulo tpuz, que faz parte do módulo libtpu.sdk no SDK Python libtpu. O módulo fornece uma vista geral do estado de cada núcleo.
O exemplo de utilização principal do TPU-Z é o diagnóstico de bloqueios ou impasses em cargas de trabalho de TPU distribuídas. Pode consultar o serviço TPU-Z em anfitriões para capturar o estado de cada núcleo, comparando os contadores de programas, as localizações HLO e os IDs de execução em todos os núcleos para identificar anomalias.
Use a função get_core_state_summary() na biblioteca libtpu.sdk para apresentar as métricas de TPU-Z:
summary = sdk.tpuz.get_core_state_summary()
O resultado das métricas de TPU-Z é fornecido como um dicionário. Segue-se um exemplo reduzido para um único núcleo:
{
"host_name": "my-tpu-host-vm",
"core_states": {
"1": {
"core_id": {
"global_core_id": 1,
"chip_id": 0,
"core_on_chip": {
"type": "TPU_CORE_TYPE_TENSOR_CORE",
"index": 1
}
},
"sequencer_info": [
{
"sequencer_type": "TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER",
"sequencer_index": 0,
"pc": 4490,
"program_id": 3274167277388825310,
"run_id": 3
}
],
"program_fingerprint": "b'\\xbefB\\xc6\\x1eb\\xc1#\\xd0...'",
"queued_program_info": [],
"error_message": ""
}
// ...
}
}
Para obter informações sobre os otimizadores de nível elevado (HLO) em cada núcleo, defina o parâmetro include_hlo_info como True:
summary = sdk.tpuz.get_core_state_summary(include_hlo_info=True)
A saída inclui informações de HLO adicionais:
"1": {
"core_id": {
"global_core_id": 1,
"chip_id": 0,
"core_on_chip": {
"type": "TPU_CORE_TYPE_TENSOR_CORE",
"index": 1
}
},
"sequencer_info": [
{
"sequencer_type": "TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER",
"sequencer_index": 0,
"pc": 17776,
"tag": 3,
"tracemark": 2147483646,
"program_id": 3230481660274331500,
"run_id": 81,
"hlo_location": "HLO: fusion.11; HLO computation: main.126_spmd",
"hlo_detailed_info": "[{\"details\":\"HloModule fusion.11, entry_computation_layout={(bf16[20>..."
}
],
"program_fingerprint": "\ufffdU\ufffd4j\u7c6e\ufffd\ufffd{\u0017\ufffd\ufffdHHV\ufffdD\ufffde\uff>",
"launch_id": 1394130914,
"queued_program_info": [
{
"run_id": 81,
"launch_id": 1394130914,
"program_fingerprint": "\ufffdU\ufffd4j\u7c6e\ufffd\ufffd{\u0017\ufffd\ufffdHHV\ufffdD\ufffde>"
}
]
}
Métricas do TPU-Z
A função get_core_state_summary devolve métricas de TPU-Z sob a forma de um dicionário com a seguinte estrutura.
CurrentCoreStateSummary
O dicionário CurrentCoreStateSummary fornece um resumo detalhado do estado de um núcleo de TPU individual.
| Campo | Tipo | Descrição |
|---|---|---|
core_id |
dicionário | Um TpuCoreIdentifierdicionário que contém informações de ID sobre o núcleo da TPU. |
sequencer_info |
lista de dicionários | Uma lista de dicionários SequencerInfo, que descrevem o estado de cada sequenciador no núcleo. |
program_fingerprint |
bytes | A impressão digital do programa que está a ser executado neste núcleo. |
launch_id |
número inteiro | O ID de lançamento do programa atual ou mais recente. |
queued_program_info |
lista de dicionários | Uma lista de QueuedProgramInfo dicionários para programas colocados em fila para execução. |
error_message |
de string | Quaisquer mensagens de erro para este núcleo. |
TpuCoreIdentifier
O dicionário TpuCoreIdentifier fornece informações de ID para núcleos no sistema de TPU.
| Campo | Tipo | Descrição |
|---|---|---|
global_core_id |
número inteiro | O ID do núcleo. |
chip_id |
número inteiro | O ID do chip ao qual o núcleo pertence. |
core_on_chip |
dicionário | Um TpuCoreOnChip dicionário que descreve o tipo do núcleo e o respetivo índice no chip. |
TpuCoreOnChip
O dicionário TpuCoreOnChip contém informações sobre as propriedades de um núcleo
num chip específico.
| Campo | Tipo | Descrição |
|---|---|---|
type |
de string | O tipo de núcleo da TPU. Por exemplo: TPU_CORE_TYPE_TENSOR_CORE. |
index |
número inteiro | O índice do núcleo no chip. |
SequencerInfo
O dicionário SequencerInfo contém informações sobre o estado de um único sequenciador num núcleo.
| Campo | Tipo | Descrição |
|---|---|---|
sequencer_type |
de string | O tipo de sequenciador. Por exemplo: TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER. |
sequencer_index |
número inteiro | O índice do sequenciador (se existirem vários do mesmo tipo). |
pc |
número inteiro | O valor atual do contador do programa. |
program_id |
número inteiro | O ID associado a uma instância específica de um programa que está a ser iniciado para execução num núcleo da TPU. |
run_id |
número inteiro | O ID de execução associado a uma instância específica da execução de um programa num núcleo da TPU. |
hlo_location |
de string | Informações de localização do otimizador de nível superior. |
hlo_detailed_info |
de string | Informações detalhadas do otimizador de nível superior. |
QueuedProgramInfo
O dicionário QueuedProgramInfo contém informações sobre programas em fila de espera
para execução num núcleo.
| Campo | Tipo | Descrição |
|---|---|---|
run_id |
número inteiro | O ID de execução do programa em fila. |
launch_id |
número inteiro | O ID de lançamento do programa em fila. |
program_fingerprint |
bytes | A impressão digital do programa em fila. |
TPU-Z com JAX
Pode aceder às métricas do TPU-Z em cargas de trabalho JAX através da biblioteca libtpu.sdk.
O seguinte script Python usa o JAX para a computação de tensores de elevado desempenho,
enquanto usa simultaneamente o SDK libtpu num processo em segundo plano para monitorizar
o estado e a atividade do hardware da TPU subjacente.
Inclua os seguintes pacotes Python:
import jax
import jax.numpy as jnp
import time
import threading
from functools import partial
from libtpu import sdk
A função monitor_tpu_status usa um tópico em segundo plano para mostrar continuamente o estado operacional dos núcleos das TPUs enquanto a aplicação principal executa uma carga de trabalho JAX. Funciona como uma ferramenta de diagnóstico em tempo real.
def monitor_tpu_status():
"""Monitors TPU status in a background thread."""
while monitoring_active:
try:
summary = sdk.tpuz.get_core_state_summary(include_hlo_info=True)
if summary and 'core_states' in summary:
print(summary)
else:
print('WARNING: Call returned an empty or invalid summary.')
except RuntimeError as e:
print(f'FAIL: Error calling API: {e}')
except Exception as e:
print(f'FAIL: Unexpected error in monitor thread: {e}')
for _ in range(MONITORING_INTERVAL_SECONDS * 2):
if not monitoring_active:
break
time.sleep(0.5)
print('✅ Monitoring thread stopped.')
A função transformer_block implementa uma camada completa da arquitetura Transformer, que é a base fundamental para os MDIs.
@partial(jax.jit, static_argnames=['num_heads'])
def transformer_block(params, x, num_heads=32):
"""A simplified but computationally intensive Transformer block."""
# Multi-head Self-Attention
qkv = jnp.dot(x, params['qkv_kernel'])
q, k, v = jnp.array_split(qkv, 3, axis=-1)
# Reshape for multi-head attention
q = q.reshape(q.shape[0], q.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
k = k.reshape(k.shape[0], k.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
v = v.reshape(v.shape[0], v.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
# Scaled dot-product attention
attention_scores = jnp.einsum('nhqd,nhkd->nhqk', q, k) / jnp.sqrt(q.shape[-1])
attention_weights = jax.nn.softmax(attention_scores, axis=-1)
attention_output = jnp.einsum('nhqk,nhvd->nhqd', attention_weights, v)
attention_output = attention_output.transpose(0, 2, 1, 3).reshape(x.shape)
attention_output = jnp.dot(attention_output, params['o_kernel'])
# Residual connection and Layer Normalization 1
h1 = x + attention_output
h1_norm = h1 - jnp.mean(h1, axis=-1, keepdims=True)
h1_norm = h1_norm / jnp.sqrt(
jnp.mean(jnp.square(h1_norm), axis=-1, keepdims=True) + 1e-5
)
# Feed-Forward Network
ffn_hidden = jax.nn.gelu(jnp.dot(h1_norm, params['ffn1_kernel']))
ffn_output = jnp.dot(ffn_hidden, params['ffn2_kernel'])
# Residual connection and Layer Normalization 2
h2 = h1_norm + ffn_output
h2_norm = h2 - jnp.mean(h2, axis=-1, keepdims=True)
h2_norm = h2_norm / jnp.sqrt(
jnp.mean(jnp.square(h2_norm), axis=-1, keepdims=True) + 1e-5
)
return h2_norm
A função main orquestra a configuração da computação JAX, inicia a monitorização da TPU em segundo plano e executa o ciclo de carga de trabalho principal.
def main():
num_devices = jax.device_count()
print(f"Running on {num_devices} devices.")
batch_size = 128 * num_devices
seq_len = 512
embed_dim = 1024
ffn_dim = embed_dim * 4
key = jax.random.PRNGKey(0)
params = {
'qkv_kernel': jax.random.normal(
key, (embed_dim, embed_dim * 3), dtype=jnp.bfloat16
),
'o_kernel': jax.random.normal(
key, (embed_dim, embed_dim), dtype=jnp.bfloat16
),
'ffn1_kernel': jax.random.normal(
key, (embed_dim, ffn_dim), dtype=jnp.bfloat16
),
'ffn2_kernel': jax.random.normal(
key, (ffn_dim, embed_dim), dtype=jnp.bfloat16
),
}
input_data = jax.random.normal(
key, (batch_size, seq_len, embed_dim), dtype=jnp.bfloat16
)
input_data = jax.device_put(input_data)
monitor_thread = threading.Thread(target=monitor_tpu_status)
monitor_thread.start()
print("Starting JAX computation loop...")
start_time = time.time()
iterations = 0
while time.time() - start_time < JOB_DURATION_SECONDS:
result = transformer_block(params, input_data)
result.block_until_ready()
iterations += 1
print(f' -> Jax iteration {iterations} complete.', end='\r')
print(f"\nCompleted {iterations} iterations in {JOB_DURATION_SECONDS} seconds.")
global monitoring_active
monitoring_active = False
monitor_thread.join()
if __name__ == '__main__':
main()
Resolução de problemas
Esta secção fornece informações de resolução de problemas para ajudar a identificar e resolver problemas que possa encontrar ao usar a biblioteca de monitorização de TPUs.
Funcionalidades ou métricas em falta
Se não conseguir ver algumas funcionalidades ou métricas, a causa mais comum é uma versão desatualizada do libtpu. As funcionalidades e as métricas da biblioteca de monitorização de TPUs estão incluídas nas versões do libtpu, e as versões desatualizadas podem não ter novas funcionalidades e métricas.
Verifique a versão do libtpu que está a ser executada no seu ambiente:
Linha de comandos:
pip show libtpu
Python:
import libtpu
print(libtpu.__version__)
Se não estiver a usar a versão mais recente do
libtpu, use o seguinte comando para atualizar a biblioteca:
pip install --upgrade libtpu