Biblioteca de monitorización de TPU
Obtén información valiosa sobre el rendimiento y el comportamiento del hardware de tu TPU de Cloud con las funciones de monitorización avanzada de TPUs, creadas directamente sobre la capa de software fundamental, LibTPU. Aunque LibTPU incluye controladores, bibliotecas de redes, el compilador XLA y el tiempo de ejecución de TPU para interactuar con las TPUs, este documento se centra en la biblioteca de monitorización de TPUs.
La biblioteca de monitorización de TPU proporciona lo siguiente:
Observabilidad completa: accede a la API de telemetría y al conjunto de métricas, que proporcionan información detallada sobre el rendimiento operativo y los comportamientos específicos de tus TPUs.
Kits de herramientas de diagnóstico: proporciona un SDK y una interfaz de línea de comandos (CLI) diseñados para permitir la depuración y el análisis detallado del rendimiento de tus recursos de TPU.
Estas funciones de monitorización se han diseñado para ser una solución de alto nivel orientada al cliente, que te proporciona las herramientas esenciales para optimizar tus cargas de trabajo de TPU de forma eficaz.
La biblioteca de monitorización de TPUs te proporciona información detallada sobre el rendimiento de las cargas de trabajo de aprendizaje automático en el hardware de TPUs. Está diseñada para ayudarte a entender el uso de tus TPUs, identificar cuellos de botella y depurar problemas de rendimiento. Proporciona información más detallada que las métricas de interrupción, las métricas de buen rendimiento y otras métricas.
Empezar a usar la biblioteca de monitorización de TPU
Acceder a estas valiosas estadísticas es muy sencillo. La función de monitorización de TPUs está integrada en el SDK de LibTPU, por lo que se incluye al instalar LibTPU.
Instalar LibTPU
pip install libtpu
Por otro lado, las actualizaciones de LibTPU se coordinan con los lanzamientos de JAX, lo que significa que, cuando instales la última versión de JAX (que se lanza mensualmente), normalmente se te asignará la última versión compatible de LibTPU y sus funciones.
Instalar JAX
pip install -U "jax[tpu]"
Los usuarios de PyTorch pueden instalar PyTorch/XLA para acceder a las funciones más recientes de LibTPU y de monitorización de TPU.
Instalar PyTorch/XLA
pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' \
-f https://storage.googleapis.com/libtpu-releases/index.html \
-f https://storage.googleapis.com/libtpu-wheels/index.html
# Optional: if you're using custom kernels, install pallas dependencies
pip install 'torch_xla[pallas]' \
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
Para obtener más información sobre cómo instalar PyTorch/XLA, consulta la sección Instalación del repositorio de GitHub de PyTorch/XLA.
Importar la biblioteca en Python
Para empezar a usar la biblioteca de monitorización de TPU, debes importar el módulo libtpu en tu código de Python.
from libtpu.sdk import tpumonitoring
Lista de todas las funciones admitidas
Lista de todos los nombres de métricas y las funciones que admiten:
from libtpu.sdk import tpumonitoring
tpumonitoring.help()
" libtpu.sdk.monitoring.help():
List all supported functionality.
libtpu.sdk.monitoring.list_support_metrics()
List support metric names in the list of str format.
libtpu.sdk.monitoring.get_metric(metric_name:str)
Get metric data with metric name. It represents the snapshot mode.
The metric data is a object with `description()` and `data()` methods,
where the `description()` returns a string describe the format of data
and data unit, `data()` returns the metric data in the list in str format.
"
Métricas admitidas
En el siguiente código de ejemplo se muestra cómo obtener una lista de todos los nombres de métricas admitidos:
from libtpu.sdk import tpumonitoring
tpumonitoring.list_supported_metrics()
["duty_cycle_pct", "tensorcore_util", "hbm_util", ...]
En la siguiente tabla se muestran todas las métricas y sus definiciones correspondientes:
| Métrica | Definición | Nombre de la métrica de la API | Valores de ejemplo |
|---|---|---|---|
| Utilización de Tensor Core | Mide el porcentaje de uso de Tensor Core, que se calcula como el porcentaje de operaciones que forman parte de las operaciones de Tensor Core. Se muestrea cada segundo durante 10 microsegundos. No puedes modificar la frecuencia de muestreo. Esta métrica te permite monitorizar la eficiencia de tus cargas de trabajo en dispositivos TPU. |
tensorcore_util
|
['1.11', '2.22', '3.33', '4.44']
# porcentaje de utilización de los IDs de acelerador del 0 al 3. |
| Porcentaje del ciclo de trabajo | Porcentaje del tiempo durante el último periodo de muestreo (cada 5 segundos; se puede ajustar configurando la marca LIBTPU_INIT_ARG) en el que el acelerador ha procesado activamente (registrado con los ciclos utilizados para ejecutar programas HLO durante el último periodo de muestreo). Esta métrica representa el nivel de actividad de una TPU. La métrica se emite por chip.
|
duty_cycle_pct
|
['10.00', '20.00', '30.00', '40.00']
# Porcentaje del ciclo de actividad del acelerador con ID del 0 al 3. |
| Capacidad total de HBM | Esta métrica indica la capacidad total de HBM en bytes. |
hbm_capacity_total
|
['30000000000', '30000000000', '30000000000', '30000000000']
# Capacidad total de HBM en bytes conectada a los IDs de acelerador del 0 al 3. |
| Uso de la capacidad de HBM | Esta métrica informa del uso de la capacidad de HBM en bytes durante el periodo de muestreo anterior (cada 5 segundos; se puede ajustar configurando la marca LIBTPU_INIT_ARG).
|
hbm_capacity_usage
|
['100', '200', '300', '400']
# Uso de la capacidad de HBM en bytes que se ha adjuntado al ID de acelerador 0-3. |
| Latencia de transferencia de búfer | Latencias de transferencia de red para tráfico multisegmento a gran escala. Esta visualización te permite comprender el entorno de rendimiento general de la red. |
buffer_transfer_latency
|
["'8MB+', '2233.25', '2182.02', '3761.93', '19277.01', '53553.6'"]
# Tamaño del búfer, media, p50, p90, p99 y p99, 9 de la distribución de la latencia de transferencia de red. |
| Métricas de distribución del tiempo de ejecución de operaciones de alto nivel | Proporciona estadísticas de rendimiento detalladas sobre el estado de ejecución del archivo binario compilado de HLO, lo que permite detectar regresiones y depurar a nivel de modelo. |
hlo_exec_timing
|
["'tensorcore-0', '10.00', '10.00', '20.00', '30.00', '40.00'"]
# Distribución de la duración del tiempo de ejecución de HLO para CoreType-CoreID con la media, p50, p90, p95 y p999. |
| Tamaño de la cola del optimizador de alto nivel | La monitorización del tamaño de la cola de ejecución de HLO registra el número de programas HLO compilados que están esperando o en proceso de ejecución. Esta métrica muestra la congestión de la canalización de ejecución, lo que permite identificar cuellos de botella en el rendimiento de la ejecución de hardware, la sobrecarga de controladores o la asignación de recursos. |
hlo_queue_size
|
["tensorcore-0: 1", "tensorcore-1: 2"]
# Mide el tamaño de la cola de CoreType-CoreID. |
| Latencia integral colectiva | Esta métrica mide la latencia colectiva de extremo a extremo en la DCN en microsegundos, desde que el host inicia la operación hasta que todos los peers reciben el resultado. Incluye la reducción de datos del host y el envío de la salida a la TPU. Los resultados son cadenas que detallan el tamaño del búfer, el tipo y las latencias media, p50, p90, p95 y p99,9. |
collective_e2e_latency
|
["8MB+-ALL_REDUCE, 1000, 2000, 3000, 4000, 5000", …]
# Operación colectiva de tamaño de transferencia, media, p50, p90, p95 y p999 de la latencia colectiva de extremo a extremo. |
| Latencia de ida y vuelta en la capa de transporte | Distribución de los tiempos de ida y vuelta mínimos observados en las conexiones TCP que usa gRPC para el tráfico de TPU multisegmento. |
grpc_tcp_min_round_trip_times
|
['27.63, 29.03, 38.52, 41.63, 52.74']
# Representa la media, el percentil 50, el 90, el 95 y el 99, 9 de la distribución en microsegundos (µs). |
| Velocidad de transferencia en la capa de transporte | Distribución acumulativa del rendimiento reciente de las conexiones TCP usadas por gRPC para el tráfico de TPU multisegmento. |
grpc_tcp_delivery_rates
|
['11354.89, 10986.35, 18239.55, 25718.35, 44841.55']
# Representa la media, el percentil 50, el 90, el 95 y el 99, 9 de la distribución en microsegundos (µs). |
Leer datos de métricas
Para leer datos de métricas, especifica el nombre de la métrica al llamar a la función tpumonitoring.get_metric. Puede insertar comprobaciones de métricas ad hoc en código de bajo rendimiento para identificar si los problemas de rendimiento se deben al software o al hardware.
En el siguiente código de ejemplo se muestra cómo leer la métrica duty_cycle:
from libtpu.sdk import tpumonitoring
metric = tpumonitoring.get_metric("duty_cycle_pct")
metric.description()
"The metric provides a list of duty cycle percentages, one for each
accelerator (from accelerator_0 to accelerator_x). The duty cycle represents
the percentage of time an accelerator was actively processing during the
last sample period, indicating TPU utilization."
metric.data()
["0.00", "0.00", "0.00", "0.00"]
# accelerator_0-3
Usar métricas para comprobar la utilización de la TPU
En los siguientes ejemplos se muestra cómo usar las métricas de la biblioteca de monitorización de TPUs para hacer un seguimiento del uso de las TPUs.
Monitorizar el ciclo de trabajo de la TPU durante el entrenamiento de JAX
Situación: estás ejecutando una secuencia de comandos de entrenamiento de JAX y quieres monitorizar la métrica duty_cycle_pct de la TPU durante todo el proceso de entrenamiento para confirmar que tus TPUs se están utilizando de forma eficaz. Puedes registrar esta métrica periódicamente durante el entrenamiento para monitorizar el uso de la TPU.
En el siguiente código de ejemplo se muestra cómo monitorizar el ciclo de trabajo de la TPU durante el entrenamiento de JAX:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
import time
# --- Your JAX model and training setup would go here ---
# --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
return jnp.sum(x)
def loss_fn(params, x, y):
preds = simple_model(x)
return jnp.mean((preds - y)**2)
def train_step(params, x, y, optimizer):
grads = jax.grad(loss_fn)(params, x, y)
return optimizer.update(grads, params)
key = jax.random.PRNGKey(0)
params = jnp.array([1.0, 2.0]) # Example params
optimizer = ... # Your optimizer (for example, optax.adam)
data_x = jnp.ones((10, 10))
data_y = jnp.zeros((10,))
num_epochs = 10
log_interval_steps = 2 # Log duty cycle every 2 steps
for epoch in range(num_epochs):
for step in range(5): # Example steps per epoch
params = train_step(params, data_x, data_y, optimizer)
if (step + 1) % log_interval_steps == 0:
# --- Integrate TPU Monitoring Library here to get duty_cycle ---
duty_cycle_metric = tpumonitoring.get_metric("duty_cycle_pct")
duty_cycle_data = duty_cycle_metric.data
print(f"Epoch {epoch+1}, Step {step+1}: TPU Duty Cycle Data:")
print(f" Description: {duty_cycle_metric.description}")
print(f" Data: {duty_cycle_data}")
# --- End TPU Monitoring Library Integration ---
# --- Rest of your training loop logic ---
time.sleep(0.1) # Simulate some computation
print("Training complete.")
Comprobar la utilización de HBM antes de ejecutar la inferencia de JAX
Situación: Antes de ejecutar la inferencia con tu modelo de JAX, comprueba el uso actual de la memoria de alto ancho de banda (HBM) en la TPU para confirmar que tienes suficiente memoria disponible y obtener una medición de referencia antes de que empiece la inferencia.
# The following code sample shows how to check HBM utilization before JAX inference:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
# --- Your JAX model and inference setup would go here ---
# --- Example placeholder model (replace with your actual model loading/setup)---
def simple_model(x):
return jnp.sum(x)
key = jax.random.PRNGKey(0)
params = ... # Load your trained parameters
# Integrate the TPU Monitoring Library to get HBM utilization before inference
hbm_util_metric = tpumonitoring.get_metric("hbm_util")
hbm_util_data = hbm_util_metric.data
print("HBM Utilization Before Inference:")
print(f" Description: {hbm_util_metric.description}")
print(f" Data: {hbm_util_data}")
# End TPU Monitoring Library Integration
# Your Inference Logic
input_data = jnp.ones((1, 10)) # Example input
predictions = simple_model(input_data)
print("Inference Predictions:", predictions)
print("Inference complete.")
Consultar métricas de red
Situación: estás ejecutando una carga de trabajo de varios hosts y varios slices y quieres conectarte a uno de los pods o las TPUs de GKE mediante SSH para ver las métricas de red mientras se ejecuta la carga de trabajo. Los comandos también se pueden incorporar directamente a la carga de trabajo de varios hosts.
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
# --- Your JAX model and training setup goes here ---
# --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
return jnp.sum(x)
# --- details here ---
# ==============================================================================
# Metric 1: TCP Delivery Rate
# ==============================================================================
# This metric reflects the delivery rate of the TCP connection (bytes delivered / elapsed time).
# The output is a list of strings representing latency statistics:
# [mean, p50, p90, p95, p99.9]
# Example: ['100.00', '200.00', '300.00', '400.00', '500.00'] (Values in Mbps)
# Get the metric object
delivery_rate_metric = tpumonitoring.get_metric(metric_name_rate)
# Print the description provided by the library
print("Description:", delivery_rate_metric.description())
# Print the actual data payload
print("Data:", delivery_rate_metric.data())
# ==============================================================================
# Metric 2: TCP Minimum Round Trip Time (RTT)
# ==============================================================================
# This metric reflects the minimum RTT measured between sending a TCP packet
# and receiving the acknowledgement.
# The output is a list of strings representing latency statistics:
# [mean, p50, p90, p95, p99.9]
# Example: ['100.00', '200.00', '300.00', '400.00', '500.00'] (Values in us - microseconds)
# Get the metric object
min_rtt_metric = tpumonitoring.get_metric(metric_name_rtt)
# Print the description provided by the library
print("Description:", min_rtt_metric.description())
# Print the actual data payload
print("Data:", min_rtt_metric.data())
Frecuencia de actualización de las métricas de TPU
La frecuencia de actualización de las métricas de TPU está limitada a un mínimo de un segundo. Los datos de métricas de host se exportan con una frecuencia fija de 1 Hz. La latencia introducida por este proceso de exportación es insignificante. Las métricas de tiempo de ejecución de LibTPU no están sujetas a la misma restricción de frecuencia. Sin embargo, para mantener la coherencia, estas métricas también se muestrean a 1 Hz o 1 muestra por segundo.
Módulo TPU-Z
TPU-Z es una herramienta de telemetría y depuración para TPUs. Proporciona información detallada sobre el estado del tiempo de ejecución de todos los núcleos de TPU conectados a un host. Esta funcionalidad se proporciona a través del módulo tpuz, que forma parte del módulo libtpu.sdk del SDK de Python libtpu. El módulo proporciona una vista general del estado de cada núcleo.
El principal caso práctico de TPU-Z es diagnosticar bloqueos o interbloqueos en cargas de trabajo de TPU distribuidas. Puedes consultar el servicio TPU-Z en los hosts para registrar el estado de cada núcleo, comparar los contadores de programa, las ubicaciones de HLO y los IDs de ejecución de todos los núcleos para identificar anomalías.
Usa la función get_core_state_summary() de la biblioteca libtpu.sdk para mostrar las métricas de TPU-Z:
summary = sdk.tpuz.get_core_state_summary()
La salida de las métricas de TPU-Z se proporciona como un diccionario. A continuación, se muestra un ejemplo abreviado de un solo núcleo:
{
"host_name": "my-tpu-host-vm",
"core_states": {
"1": {
"core_id": {
"global_core_id": 1,
"chip_id": 0,
"core_on_chip": {
"type": "TPU_CORE_TYPE_TENSOR_CORE",
"index": 1
}
},
"sequencer_info": [
{
"sequencer_type": "TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER",
"sequencer_index": 0,
"pc": 4490,
"program_id": 3274167277388825310,
"run_id": 3
}
],
"program_fingerprint": "b'\\xbefB\\xc6\\x1eb\\xc1#\\xd0...'",
"queued_program_info": [],
"error_message": ""
}
// ...
}
}
Para obtener información sobre los optimizadores de alto nivel (HLO) de cada núcleo, asigna el valor True al parámetro include_hlo_info:
summary = sdk.tpuz.get_core_state_summary(include_hlo_info=True)
La salida incluye información adicional de HLO:
"1": {
"core_id": {
"global_core_id": 1,
"chip_id": 0,
"core_on_chip": {
"type": "TPU_CORE_TYPE_TENSOR_CORE",
"index": 1
}
},
"sequencer_info": [
{
"sequencer_type": "TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER",
"sequencer_index": 0,
"pc": 17776,
"tag": 3,
"tracemark": 2147483646,
"program_id": 3230481660274331500,
"run_id": 81,
"hlo_location": "HLO: fusion.11; HLO computation: main.126_spmd",
"hlo_detailed_info": "[{\"details\":\"HloModule fusion.11, entry_computation_layout={(bf16[20>..."
}
],
"program_fingerprint": "\ufffdU\ufffd4j\u7c6e\ufffd\ufffd{\u0017\ufffd\ufffdHHV\ufffdD\ufffde\uff>",
"launch_id": 1394130914,
"queued_program_info": [
{
"run_id": 81,
"launch_id": 1394130914,
"program_fingerprint": "\ufffdU\ufffd4j\u7c6e\ufffd\ufffd{\u0017\ufffd\ufffdHHV\ufffdD\ufffde>"
}
]
}
Métricas de TPU-Z
La función get_core_state_summary devuelve métricas de TPU-Z en forma de diccionario con la siguiente estructura.
CurrentCoreStateSummary
El diccionario CurrentCoreStateSummary proporciona un resumen detallado del estado de un núcleo de TPU concreto.
| Campo | Tipo | Descripción |
|---|---|---|
core_id |
diccionario | Un diccionario TpuCoreIdentifier que contiene información de ID sobre el núcleo de la TPU. |
sequencer_info |
lista de diccionarios | Lista de diccionarios SequencerInfo que describen el estado de cada secuenciador del núcleo. |
program_fingerprint |
bytes | Huella digital del programa que se está ejecutando en este núcleo. |
launch_id |
entero | ID de lanzamiento del programa actual o más reciente. |
queued_program_info |
lista de diccionarios | Una lista de diccionarios QueuedProgramInfo para los programas que están en cola para ejecutarse. |
error_message |
cadena | Mensajes de error de este elemento principal. |
TpuCoreIdentifier
El diccionario TpuCoreIdentifier proporciona información de ID de los núcleos del sistema de TPU.
| Campo | Tipo | Descripción |
|---|---|---|
global_core_id |
entero | Es el ID del elemento principal. |
chip_id |
entero | El ID del chip al que pertenece el núcleo. |
core_on_chip |
diccionario | Un diccionario TpuCoreOnChip que describe el tipo del núcleo y su índice en el chip. |
TpuCoreOnChip
El diccionario TpuCoreOnChip contiene información sobre las propiedades de un núcleo
de un chip específico.
| Campo | Tipo | Descripción |
|---|---|---|
type |
cadena | Tipo de núcleo de TPU. Por ejemplo: TPU_CORE_TYPE_TENSOR_CORE. |
index |
entero | Índice del núcleo del chip. |
SequencerInfo
El diccionario SequencerInfo contiene información sobre el estado de un secuenciador de un núcleo.
| Campo | Tipo | Descripción |
|---|---|---|
sequencer_type |
cadena | El tipo de secuenciador. Por ejemplo: TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER. |
sequencer_index |
entero | Índice del secuenciador (si hay varios del mismo tipo). |
pc |
entero | El valor actual del contador de programa. |
program_id |
entero | El ID asociado a una instancia específica de un programa que se inicia para ejecutarse en un núcleo de TPU. |
run_id |
entero | El ID de ejecución asociado a una instancia específica de la ejecución de un programa en un core de TPU. |
hlo_location |
cadena | Información de ubicación de High Level Optimizer. |
hlo_detailed_info |
cadena | Información detallada del optimizador de alto nivel. |
QueuedProgramInfo
El diccionario QueuedProgramInfo contiene información sobre los programas puestos en cola
para ejecutarse en un núcleo.
| Campo | Tipo | Descripción |
|---|---|---|
run_id |
entero | ID de ejecución del programa en cola. |
launch_id |
entero | Es el ID de lanzamiento del programa en cola. |
program_fingerprint |
bytes | La huella digital del programa en cola. |
TPU-Z con JAX
Puedes acceder a las métricas de TPU-Z en cargas de trabajo de JAX a través de la biblioteca libtpu.sdk.
La siguiente secuencia de comandos de Python usa JAX para realizar cálculos de tensores de alto rendimiento, al tiempo que usa el SDK de libtpu en un subproceso en segundo plano para monitorizar el estado y la actividad del hardware de TPU subyacente.
Incluye los siguientes paquetes de Python:
import jax
import jax.numpy as jnp
import time
import threading
from functools import partial
from libtpu import sdk
La función monitor_tpu_status usa un subproceso en segundo plano para mostrar continuamente el estado operativo de los núcleos de las TPUs mientras la aplicación principal ejecuta una carga de trabajo de JAX. Actúa como una herramienta de diagnóstico en tiempo real.
def monitor_tpu_status():
"""Monitors TPU status in a background thread."""
while monitoring_active:
try:
summary = sdk.tpuz.get_core_state_summary(include_hlo_info=True)
if summary and 'core_states' in summary:
print(summary)
else:
print('WARNING: Call returned an empty or invalid summary.')
except RuntimeError as e:
print(f'FAIL: Error calling API: {e}')
except Exception as e:
print(f'FAIL: Unexpected error in monitor thread: {e}')
for _ in range(MONITORING_INTERVAL_SECONDS * 2):
if not monitoring_active:
break
time.sleep(0.5)
print('✅ Monitoring thread stopped.')
La función transformer_block implementa una capa completa de la arquitectura Transformer, que es el componente básico de los LLMs.
@partial(jax.jit, static_argnames=['num_heads'])
def transformer_block(params, x, num_heads=32):
"""A simplified but computationally intensive Transformer block."""
# Multi-head Self-Attention
qkv = jnp.dot(x, params['qkv_kernel'])
q, k, v = jnp.array_split(qkv, 3, axis=-1)
# Reshape for multi-head attention
q = q.reshape(q.shape[0], q.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
k = k.reshape(k.shape[0], k.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
v = v.reshape(v.shape[0], v.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
# Scaled dot-product attention
attention_scores = jnp.einsum('nhqd,nhkd->nhqk', q, k) / jnp.sqrt(q.shape[-1])
attention_weights = jax.nn.softmax(attention_scores, axis=-1)
attention_output = jnp.einsum('nhqk,nhvd->nhqd', attention_weights, v)
attention_output = attention_output.transpose(0, 2, 1, 3).reshape(x.shape)
attention_output = jnp.dot(attention_output, params['o_kernel'])
# Residual connection and Layer Normalization 1
h1 = x + attention_output
h1_norm = h1 - jnp.mean(h1, axis=-1, keepdims=True)
h1_norm = h1_norm / jnp.sqrt(
jnp.mean(jnp.square(h1_norm), axis=-1, keepdims=True) + 1e-5
)
# Feed-Forward Network
ffn_hidden = jax.nn.gelu(jnp.dot(h1_norm, params['ffn1_kernel']))
ffn_output = jnp.dot(ffn_hidden, params['ffn2_kernel'])
# Residual connection and Layer Normalization 2
h2 = h1_norm + ffn_output
h2_norm = h2 - jnp.mean(h2, axis=-1, keepdims=True)
h2_norm = h2_norm / jnp.sqrt(
jnp.mean(jnp.square(h2_norm), axis=-1, keepdims=True) + 1e-5
)
return h2_norm
La función main coordina la configuración del cálculo de JAX, inicia la monitorización de TPU en segundo plano y ejecuta el bucle de carga de trabajo principal.
def main():
num_devices = jax.device_count()
print(f"Running on {num_devices} devices.")
batch_size = 128 * num_devices
seq_len = 512
embed_dim = 1024
ffn_dim = embed_dim * 4
key = jax.random.PRNGKey(0)
params = {
'qkv_kernel': jax.random.normal(
key, (embed_dim, embed_dim * 3), dtype=jnp.bfloat16
),
'o_kernel': jax.random.normal(
key, (embed_dim, embed_dim), dtype=jnp.bfloat16
),
'ffn1_kernel': jax.random.normal(
key, (embed_dim, ffn_dim), dtype=jnp.bfloat16
),
'ffn2_kernel': jax.random.normal(
key, (ffn_dim, embed_dim), dtype=jnp.bfloat16
),
}
input_data = jax.random.normal(
key, (batch_size, seq_len, embed_dim), dtype=jnp.bfloat16
)
input_data = jax.device_put(input_data)
monitor_thread = threading.Thread(target=monitor_tpu_status)
monitor_thread.start()
print("Starting JAX computation loop...")
start_time = time.time()
iterations = 0
while time.time() - start_time < JOB_DURATION_SECONDS:
result = transformer_block(params, input_data)
result.block_until_ready()
iterations += 1
print(f' -> Jax iteration {iterations} complete.', end='\r')
print(f"\nCompleted {iterations} iterations in {JOB_DURATION_SECONDS} seconds.")
global monitoring_active
monitoring_active = False
monitor_thread.join()
if __name__ == '__main__':
main()
Solución de problemas
En esta sección se proporciona información para ayudarte a identificar y resolver los problemas que puedan surgir al usar la biblioteca de monitorización de TPUs.
Faltan funciones o métricas
Si no puede ver algunas funciones o métricas, lo más probable es que esté usando una versión libtpu obsoleta. Las funciones y métricas de la biblioteca de monitorización de TPUs se incluyen en las versiones de libtpu, y es posible que las versiones obsoletas no tengan las nuevas funciones y métricas.
Comprueba la versión de libtpu que se está ejecutando en tu entorno:
Línea de comandos:
pip show libtpu
Python:
import libtpu
print(libtpu.__version__)
Si no usas la versión más reciente de
libtpu, usa el siguiente comando para actualizar la biblioteca:
pip install --upgrade libtpu