Biblioteca de supervisión de TPU
Obtén estadísticas detalladas sobre el rendimiento y el comportamiento del hardware de Cloud TPU con las capacidades avanzadas de supervisión de TPU, creadas directamente sobre la capa básica de software, LibTPU. Si bien este último incluye controladores, bibliotecas de redes, el compilador XLA y el entorno de ejecución de TPU para interactuar con las TPUs, este documento se centra en su biblioteca de supervisión.
La biblioteca de supervisión de TPU proporciona lo siguiente:
Observabilidad integral: Obtén acceso a la API de telemetría y al conjunto de métricas, que proporcionan estadísticas detalladas sobre el rendimiento operativo y el comportamiento específico de tus TPUs.
Kits de herramientas de diagnóstico: Proporcionan un SDK y una interfaz de línea de comandos (CLI) diseñados para permitir la depuración y el análisis de rendimiento detallado de tus recursos de TPU.
Estas funciones de supervisión están diseñadas con el objetivo de brindar una solución de alto nivel orientada al cliente, que te proporciona las herramientas esenciales para optimizar tus cargas de trabajo de TPU con eficacia.
La biblioteca de supervisión de TPU te brinda información detallada sobre el rendimiento de las cargas de trabajo de aprendizaje automático en el hardware de TPU. Está diseñada para ayudarte a comprender el uso de la TPU, identificar cuellos de botella y depurar problemas de rendimiento. Además, te brinda información más detallada que las métricas de interrupción, de buen rendimiento, entre otras.
Empieza a usar la biblioteca de supervisión de TPU
Acceder a estas estadísticas valiosas es sencillo. La funcionalidad de supervisión de TPU está integrada en el SDK de LibTPU, por lo que se incluye cuando se instala.
Instala LibTPU
pip install libtpu
Como alternativa, las actualizaciones de LibTPU se coordinan con los lanzamientos de JAX. Esto significa que, cuando instalas el lanzamiento de JAX más reciente (por mes), por lo general, se fijará en la versión de LibTPU compatible más reciente y sus funciones.
Instala JAX
pip install -U "jax[tpu]"
Para los usuarios de PyTorch, instalar PyTorch/XLA proporciona la funcionalidad de supervisión de LibTPU y TPU más reciente.
Instala PyTorch/XLA
pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' \
-f https://storage.googleapis.com/libtpu-releases/index.html \
-f https://storage.googleapis.com/libtpu-wheels/index.html
# Optional: if you're using custom kernels, install pallas dependencies
pip install 'torch_xla[pallas]' \
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
Si deseas obtener más información para instalar PyTorch/XLA, consulta Instalación en el repositorio de GitHub de PyTorch/XLA.
Importa la biblioteca en Python
Para empezar a usar la biblioteca de supervisión de TPU, debes importar el módulo libtpu en tu código de Python.
from libtpu.sdk import tpumonitoring
Enumera todas las funciones compatibles
Enumera todos los nombres de las métricas y la funcionalidad que se admiten:
from libtpu.sdk import tpumonitoring
tpumonitoring.help()
" libtpu.sdk.monitoring.help():
List all supported functionality.
libtpu.sdk.monitoring.list_support_metrics()
List support metric names in the list of str format.
libtpu.sdk.monitoring.get_metric(metric_name:str)
Get metric data with metric name. It represents the snapshot mode.
The metric data is a object with `description()` and `data()` methods,
where the `description()` returns a string describe the format of data
and data unit, `data()` returns the metric data in the list in str format.
"
Métricas admitidas
En la siguiente muestra de código, se muestra cómo enumerar todos los nombres de métricas admitidas:
from libtpu.sdk import tpumonitoring
tpumonitoring.list_supported_metrics()
["duty_cycle_pct", "tensorcore_util", "hbm_util", ...]
En la siguiente tabla, se muestran todas las métricas y sus definiciones correspondientes:
| Métrica | Definición | Nombre de la métrica para la API | Valores de ejemplo |
|---|---|---|---|
| Uso de Tensor Core | Mide el porcentaje de uso de TensorCore, calculado como el porcentaje de las operaciones que forman parte de TensorCore. Se hizo un muestreo de 10 microsegundos cada 1 segundo. No puedes modificar la tasa de muestreo. Esta métrica te permite supervisar la eficiencia de tus cargas de trabajo en dispositivos TPU. |
tensorcore_util
|
['1.11', '2.22', '3.33', '4.44']
# porcentaje de uso para los IDs de acelerador del 0 al 3. |
| Porcentaje del ciclo de trabajo | Es el porcentaje de tiempo durante el último período de muestra (cada 5 segundos. Se puede ajustar configurando la marca LIBTPU_INIT_ARG) durante el que el acelerador se procesaba de forma activa (registrado con los ciclos que se usaron para ejecutar programas HLO durante el último período de muestreo). Esta métrica representa qué tan ocupada está una TPU y se emite por chip.
|
duty_cycle_pct
|
['10.00', '20.00', '30.00', '40.00']
# Porcentaje del ciclo de trabajo para los IDs de acelerador del 0 al 3. |
| Capacidad total de HBM | Esta métrica informa la capacidad total de HBM en bytes. |
hbm_capacity_total
|
['30000000000', '30000000000', '30000000000', '30000000000']
# Capacidad total de HBM en bytes que se adjunta a los IDs de acelerador del 0 al 3. |
| Uso de la capacidad de HBM | Esta métrica informa el uso de la capacidad de HBM en bytes durante el último período de muestreo (cada 5 segundos; se puede ajustar configurando la marca LIBTPU_INIT_ARG).
|
hbm_capacity_usage
|
['100', '200', '300', '400']
# Uso de capacidad para HBM en bytes que se adjunta a los IDs de acelerador del 0 al 3. |
| Latencia de transferencia de búfer | Latencias de transferencia de red para el tráfico de varias porciones a gran escala. Esta visualización te permite comprender el entorno general de rendimiento de la red. |
buffer_transfer_latency
|
["'8MB+', '2233.25', '2182.02', '3761.93', '19277.01', '53553.6'"]
# tamaño del búfer, media, p50, p90, p99 y p99.9 de la distribución de latencia de transferencia de red. |
| Métricas de distribución del tiempo de ejecución de operaciones de alto nivel | Proporciona estadísticas detalladas de rendimiento sobre el estado de ejecución del objeto binario compilado de HLO, lo que permite la detección de regresiones y la depuración a nivel del modelo. |
hlo_exec_timing
|
["'tensorcore-0', '10.00', '10.00', '20.00', '30.00', '40.00'"]
# La distribución de la duración del tiempo de ejecución del HLO para CoreType-CoreID con la media, p50, p90, p95 y p999. |
| Tamaño de la cola del optimizador de alto nivel | La supervisión del tamaño de la cola de ejecución de HLO hace un seguimiento de la cantidad de programas HLO compilados que están en espera o en ejecución. Esta métrica revela la congestión de la canalización de ejecución, lo que permite identificar los cuellos de botella de rendimiento en la ejecución del hardware, la sobrecarga del controlador o la asignación de los recursos. |
hlo_queue_size
|
["tensorcore-0: 1", "tensorcore-1: 2"]
# Mide el tamaño de la cola para CoreType-CoreID. |
| Latencia colectiva de extremo a extremo | Esta métrica mide la latencia colectiva de extremo a extremo de la DCN en microsegundos, desde el host que inicia la operación hasta todos los pares que reciben el resultado. Incluye la reducción de datos del host y el envío de resultados a la TPU. Los resultados son cadenas que detallan el tamaño del búfer, el tipo y las latencias medias, p50, p90, p95 y p99.9. |
collective_e2e_latency
|
["8MB+-ALL_REDUCE, 1000, 2000, 3000, 4000, 5000", …]
# Tamaño de la transferencia por operación colectiva y métricas de latencia de extremo a extremo, media, p50, p90, p95, p999. |
| Latencia de ida y vuelta en la capa de transporte | Es la distribución de los tiempos de ida y vuelta (RTT) mínimos observados en las conexiones TCP que usa gRPC para el tráfico de TPU de varias porciones. |
grpc_tcp_min_round_trip_times
|
['27.63, 29.03, 38.52, 41.63, 52.74']
# Representa la media, los percentiles p50, p90, p95 y p99.9 de la distribución en microsegundos (µs). |
| Capacidad de procesamiento en la capa de transporte | Es la distribución acumulativa de la capacidad de procesamiento reciente de las conexiones TCP que usa gRPC para el tráfico de TPU de varias porciones. |
grpc_tcp_delivery_rates
|
['11354.89, 10986.35, 18239.55, 25718.35, 44841.55']
# Representa la media, los percentiles p50, p90, p95 y p99.9 de la distribución en microsegundos (µs). |
Lee datos de métricas
Para leer los datos de métricas, especifica el nombre de la métrica cuando llames a la función tpumonitoring.get_metric. Puedes insertar comprobaciones de métricas ad hoc en código de bajo rendimiento para identificar si los problemas de rendimiento provienen del software o del hardware.
En la siguiente muestra de código, se explica cómo leer la métrica duty_cycle:
from libtpu.sdk import tpumonitoring
metric = tpumonitoring.get_metric("duty_cycle_pct")
metric.description()
"The metric provides a list of duty cycle percentages, one for each
accelerator (from accelerator_0 to accelerator_x). The duty cycle represents
the percentage of time an accelerator was actively processing during the
last sample period, indicating TPU utilization."
metric.data()
["0.00", "0.00", "0.00", "0.00"]
# accelerator_0-3
Usa métricas para comprobar el uso de la TPU
En los siguientes ejemplos, se muestra cómo usar las métricas de la biblioteca de supervisión de TPU para hacer un seguimiento del uso de la TPU.
Supervisa el ciclo de trabajo de la TPU durante el entrenamiento de JAX
Situación: Estás ejecutando una secuencia de comandos de entrenamiento de JAX y deseas supervisar la métrica duty_cycle_pct de la TPU durante todo el proceso de entrenamiento para confirmar que las TPUs se usan con eficacia. Puedes registrar esta métrica periódicamente durante el entrenamiento para hacer un seguimiento del uso de la TPU.
En la siguiente muestra de código, se explica cómo supervisar el ciclo de trabajo de la TPU durante el entrenamiento de JAX:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
import time
# --- Your JAX model and training setup would go here ---
# --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
return jnp.sum(x)
def loss_fn(params, x, y):
preds = simple_model(x)
return jnp.mean((preds - y)**2)
def train_step(params, x, y, optimizer):
grads = jax.grad(loss_fn)(params, x, y)
return optimizer.update(grads, params)
key = jax.random.PRNGKey(0)
params = jnp.array([1.0, 2.0]) # Example params
optimizer = ... # Your optimizer (for example, optax.adam)
data_x = jnp.ones((10, 10))
data_y = jnp.zeros((10,))
num_epochs = 10
log_interval_steps = 2 # Log duty cycle every 2 steps
for epoch in range(num_epochs):
for step in range(5): # Example steps per epoch
params = train_step(params, data_x, data_y, optimizer)
if (step + 1) % log_interval_steps == 0:
# --- Integrate TPU Monitoring Library here to get duty_cycle ---
duty_cycle_metric = tpumonitoring.get_metric("duty_cycle_pct")
duty_cycle_data = duty_cycle_metric.data
print(f"Epoch {epoch+1}, Step {step+1}: TPU Duty Cycle Data:")
print(f" Description: {duty_cycle_metric.description}")
print(f" Data: {duty_cycle_data}")
# --- End TPU Monitoring Library Integration ---
# --- Rest of your training loop logic ---
time.sleep(0.1) # Simulate some computation
print("Training complete.")
Comprueba el uso de la HBM antes de ejecutar la inferencia de JAX
Situación: Antes de ejecutar la inferencia con tu modelo de JAX, comprueba el uso actual de la HBM (memoria de ancho de banda elevado) en la TPU para confirmar que tienes suficiente memoria disponible y obtener una medición de referencia antes de que empiece la inferencia.
# The following code sample shows how to check HBM utilization before JAX inference:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
# --- Your JAX model and inference setup would go here ---
# --- Example placeholder model (replace with your actual model loading/setup)---
def simple_model(x):
return jnp.sum(x)
key = jax.random.PRNGKey(0)
params = ... # Load your trained parameters
# Integrate the TPU Monitoring Library to get HBM utilization before inference
hbm_util_metric = tpumonitoring.get_metric("hbm_util")
hbm_util_data = hbm_util_metric.data
print("HBM Utilization Before Inference:")
print(f" Description: {hbm_util_metric.description}")
print(f" Data: {hbm_util_data}")
# End TPU Monitoring Library Integration
# Your Inference Logic
input_data = jnp.ones((1, 10)) # Example input
predictions = simple_model(input_data)
print("Inference Predictions:", predictions)
print("Inference complete.")
Consulta las métricas de red
Situación: Estás ejecutando una carga de trabajo de varios hosts y de varias porciones, y quieres conectarte a uno de los Pods o las TPUs de GKE con SSH para ver las métricas de red mientras se ejecuta la carga de trabajo. Los comandos también se pueden incorporar directamente en la carga de trabajo de varios hosts.
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
# --- Your JAX model and training setup goes here ---
# --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
return jnp.sum(x)
# --- details here ---
# ==============================================================================
# Metric 1: TCP Delivery Rate
# ==============================================================================
# This metric reflects the delivery rate of the TCP connection (bytes delivered / elapsed time).
# The output is a list of strings representing latency statistics:
# [mean, p50, p90, p95, p99.9]
# Example: ['100.00', '200.00', '300.00', '400.00', '500.00'] (Values in Mbps)
# Get the metric object
delivery_rate_metric = tpumonitoring.get_metric(metric_name_rate)
# Print the description provided by the library
print("Description:", delivery_rate_metric.description())
# Print the actual data payload
print("Data:", delivery_rate_metric.data())
# ==============================================================================
# Metric 2: TCP Minimum Round Trip Time (RTT)
# ==============================================================================
# This metric reflects the minimum RTT measured between sending a TCP packet
# and receiving the acknowledgement.
# The output is a list of strings representing latency statistics:
# [mean, p50, p90, p95, p99.9]
# Example: ['100.00', '200.00', '300.00', '400.00', '500.00'] (Values in us - microseconds)
# Get the metric object
min_rtt_metric = tpumonitoring.get_metric(metric_name_rtt)
# Print the description provided by the library
print("Description:", min_rtt_metric.description())
# Print the actual data payload
print("Data:", min_rtt_metric.data())
Frecuencia de actualización de las métricas de TPU
La frecuencia de actualización de las métricas de TPU se limita a un mínimo de un segundo. Los datos de las métricas del host se exportan con una frecuencia fija de 1 Hz, y la latencia que introduce este proceso de exportación es insignificante. Las métricas del entorno de ejecución de LibTPU no están sujetas a la misma restricción de frecuencia. Sin embargo, para mantener la coherencia, estas métricas también se muestrean a 1 Hz o 1 muestra por segundo.
Módulo de TPU-Z
TPU-Z es una herramienta de telemetría y depuración para las TPUs. Proporciona información detallada sobre el estado del entorno de ejecución de todos los núcleos de TPU conectados a un host. La funcionalidad se proporciona por medio del módulo tpuz, que forma parte del módulo libtpu.sdk en el SDK de libtpu de Python. El módulo proporciona una instantánea del estado de cada núcleo.
El caso de uso principal de TPU-Z es diagnosticar bloqueos o interbloqueos en cargas de trabajo de TPU distribuidas. Puedes consultar el servicio TPU-Z en los hosts para capturar el estado de cada núcleo, comparar los contadores de programa, las ubicaciones de HLO y los IDs de ejecución en todos los núcleos para identificar anomalías.
Usa la función get_core_state_summary() en la biblioteca libtpu.sdk para mostrar las métricas de TPU-Z:
summary = sdk.tpuz.get_core_state_summary()
El resultado de las métricas de TPU-Z se proporciona como un diccionario. A continuación, se muestra un ejemplo truncado para un solo núcleo:
{
"host_name": "my-tpu-host-vm",
"core_states": {
"1": {
"core_id": {
"global_core_id": 1,
"chip_id": 0,
"core_on_chip": {
"type": "TPU_CORE_TYPE_TENSOR_CORE",
"index": 1
}
},
"sequencer_info": [
{
"sequencer_type": "TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER",
"sequencer_index": 0,
"pc": 4490,
"program_id": 3274167277388825310,
"run_id": 3
}
],
"program_fingerprint": "b'\\xbefB\\xc6\\x1eb\\xc1#\\xd0...'",
"queued_program_info": [],
"error_message": ""
}
// ...
}
}
Para recuperar información sobre los optimizadores de alto nivel (HLO) en cada núcleo, establece el parámetro include_hlo_info en True:
summary = sdk.tpuz.get_core_state_summary(include_hlo_info=True)
El resultado incluye información adicional sobre HLO:
"1": {
"core_id": {
"global_core_id": 1,
"chip_id": 0,
"core_on_chip": {
"type": "TPU_CORE_TYPE_TENSOR_CORE",
"index": 1
}
},
"sequencer_info": [
{
"sequencer_type": "TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER",
"sequencer_index": 0,
"pc": 17776,
"tag": 3,
"tracemark": 2147483646,
"program_id": 3230481660274331500,
"run_id": 81,
"hlo_location": "HLO: fusion.11; HLO computation: main.126_spmd",
"hlo_detailed_info": "[{\"details\":\"HloModule fusion.11, entry_computation_layout={(bf16[20>..."
}
],
"program_fingerprint": "\ufffdU\ufffd4j\u7c6e\ufffd\ufffd{\u0017\ufffd\ufffdHHV\ufffdD\ufffde\uff>",
"launch_id": 1394130914,
"queued_program_info": [
{
"run_id": 81,
"launch_id": 1394130914,
"program_fingerprint": "\ufffdU\ufffd4j\u7c6e\ufffd\ufffd{\u0017\ufffd\ufffdHHV\ufffdD\ufffde>"
}
]
}
Métricas de TPU-Z
La función get_core_state_summary muestra las métricas de TPU-Z en forma de un diccionario con la siguiente estructura.
CurrentCoreStateSummary
El diccionario CurrentCoreStateSummary proporciona un resumen detallado del estado de un núcleo de TPU individual.
| Campo | Tipo | Descripción |
|---|---|---|
core_id |
dictionary | Es un diccionario TpuCoreIdentifier que contiene información del ID sobre el núcleo de la TPU. |
sequencer_info |
list of dictionaries | Es una lista de diccionarios SequencerInfo en la que se describe el estado de cada secuenciador en el núcleo. |
program_fingerprint |
bytes | Es la huella digital del programa que se ejecuta en este núcleo. |
launch_id |
integer | Es el ID de la ejecución del programa actual o del más reciente. |
queued_program_info |
list of dictionaries | Es una lista de diccionarios QueuedProgramInfo para los programas en cola de ejecución. |
error_message |
string | Son los mensajes de error de este núcleo. |
TpuCoreIdentifier
El diccionario TpuCoreIdentifier proporciona información del ID para los núcleos en el sistema de TPU.
| Campo | Tipo | Descripción |
|---|---|---|
global_core_id |
integer | Es el ID del núcleo. |
chip_id |
integer | Es el ID del chip al que pertenece el núcleo. |
core_on_chip |
dictionary | Es un diccionario TpuCoreOnChip que describe el tipo del núcleo y su índice en el chip. |
TpuCoreOnChip
El diccionario TpuCoreOnChip contiene información sobre las propiedades de un núcleo en un chip específico.
| Campo | Tipo | Descripción |
|---|---|---|
type |
string | Es el tipo de núcleo de TPU. Por ejemplo: TPU_CORE_TYPE_TENSOR_CORE. |
index |
integer | Es el índice del núcleo en el chip. |
SequencerInfo
El diccionario SequencerInfo contiene información sobre el estado de un solo secuenciador en un núcleo.
| Campo | Tipo | Descripción |
|---|---|---|
sequencer_type |
string | Es el tipo de secuenciador. Por ejemplo: TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER. |
sequencer_index |
integer | Es el índice del secuenciador (si hay varios del mismo tipo). |
pc |
integer | Es el valor actual del contador de programa. |
program_id |
integer | Es el ID asociado a una instancia específica de un programa que se inicia para su ejecución en un núcleo de TPU. |
run_id |
integer | Es el ID de ejecución asociado a una instancia específica de la ejecución de un programa en un núcleo de TPU. |
hlo_location |
string | Es la información de ubicación del optimizador de alto nivel. |
hlo_detailed_info |
string | Es información detallada sobre el optimizador de alto nivel. |
QueuedProgramInfo
El diccionario QueuedProgramInfo contiene información sobre los programas en cola para su ejecución en un núcleo.
| Campo | Tipo | Descripción |
|---|---|---|
run_id |
integer | Es el ID de ejecución del programa en cola. |
launch_id |
integer | Es el ID de lanzamiento del programa en cola. |
program_fingerprint |
bytes | Es la huella digital del programa en cola. |
TPU-Z con JAX
Puedes acceder a las métricas de TPU-Z en las cargas de trabajo de JAX con la biblioteca libtpu.sdk.
La siguiente secuencia de comandos de Python usa JAX para el procesamiento de tensores de alto rendimiento y, al mismo tiempo, usa el SDK de libtpu en un subproceso en segundo plano. Todo esto con el objetivo de supervisar el estado y la actividad del hardware de TPU subyacente.
Incluye los siguientes paquetes de Python:
import jax
import jax.numpy as jnp
import time
import threading
from functools import partial
from libtpu import sdk
La función monitor_tpu_status usa un subproceso en segundo plano para mostrar de forma continua el estado operativo de los núcleos de las TPUs mientras la aplicación principal ejecuta una carga de trabajo de JAX. Actúa como una herramienta de diagnóstico en tiempo real.
def monitor_tpu_status():
"""Monitors TPU status in a background thread."""
while monitoring_active:
try:
summary = sdk.tpuz.get_core_state_summary(include_hlo_info=True)
if summary and 'core_states' in summary:
print(summary)
else:
print('WARNING: Call returned an empty or invalid summary.')
except RuntimeError as e:
print(f'FAIL: Error calling API: {e}')
except Exception as e:
print(f'FAIL: Unexpected error in monitor thread: {e}')
for _ in range(MONITORING_INTERVAL_SECONDS * 2):
if not monitoring_active:
break
time.sleep(0.5)
print('✅ Monitoring thread stopped.')
La función transformer_block implementa una capa completa de la arquitectura Transformer, que es el componente básico de los LLM.
@partial(jax.jit, static_argnames=['num_heads'])
def transformer_block(params, x, num_heads=32):
"""A simplified but computationally intensive Transformer block."""
# Multi-head Self-Attention
qkv = jnp.dot(x, params['qkv_kernel'])
q, k, v = jnp.array_split(qkv, 3, axis=-1)
# Reshape for multi-head attention
q = q.reshape(q.shape[0], q.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
k = k.reshape(k.shape[0], k.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
v = v.reshape(v.shape[0], v.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
# Scaled dot-product attention
attention_scores = jnp.einsum('nhqd,nhkd->nhqk', q, k) / jnp.sqrt(q.shape[-1])
attention_weights = jax.nn.softmax(attention_scores, axis=-1)
attention_output = jnp.einsum('nhqk,nhvd->nhqd', attention_weights, v)
attention_output = attention_output.transpose(0, 2, 1, 3).reshape(x.shape)
attention_output = jnp.dot(attention_output, params['o_kernel'])
# Residual connection and Layer Normalization 1
h1 = x + attention_output
h1_norm = h1 - jnp.mean(h1, axis=-1, keepdims=True)
h1_norm = h1_norm / jnp.sqrt(
jnp.mean(jnp.square(h1_norm), axis=-1, keepdims=True) + 1e-5
)
# Feed-Forward Network
ffn_hidden = jax.nn.gelu(jnp.dot(h1_norm, params['ffn1_kernel']))
ffn_output = jnp.dot(ffn_hidden, params['ffn2_kernel'])
# Residual connection and Layer Normalization 2
h2 = h1_norm + ffn_output
h2_norm = h2 - jnp.mean(h2, axis=-1, keepdims=True)
h2_norm = h2_norm / jnp.sqrt(
jnp.mean(jnp.square(h2_norm), axis=-1, keepdims=True) + 1e-5
)
return h2_norm
La función main coordina la configuración del cálculo de JAX, inicia la supervisión en segundo plano de la TPU y ejecuta el bucle principal de la carga de trabajo.
def main():
num_devices = jax.device_count()
print(f"Running on {num_devices} devices.")
batch_size = 128 * num_devices
seq_len = 512
embed_dim = 1024
ffn_dim = embed_dim * 4
key = jax.random.PRNGKey(0)
params = {
'qkv_kernel': jax.random.normal(
key, (embed_dim, embed_dim * 3), dtype=jnp.bfloat16
),
'o_kernel': jax.random.normal(
key, (embed_dim, embed_dim), dtype=jnp.bfloat16
),
'ffn1_kernel': jax.random.normal(
key, (embed_dim, ffn_dim), dtype=jnp.bfloat16
),
'ffn2_kernel': jax.random.normal(
key, (ffn_dim, embed_dim), dtype=jnp.bfloat16
),
}
input_data = jax.random.normal(
key, (batch_size, seq_len, embed_dim), dtype=jnp.bfloat16
)
input_data = jax.device_put(input_data)
monitor_thread = threading.Thread(target=monitor_tpu_status)
monitor_thread.start()
print("Starting JAX computation loop...")
start_time = time.time()
iterations = 0
while time.time() - start_time < JOB_DURATION_SECONDS:
result = transformer_block(params, input_data)
result.block_until_ready()
iterations += 1
print(f' -> Jax iteration {iterations} complete.', end='\r')
print(f"\nCompleted {iterations} iterations in {JOB_DURATION_SECONDS} seconds.")
global monitoring_active
monitoring_active = False
monitor_thread.join()
if __name__ == '__main__':
main()
Soluciona problemas
En esta sección, se proporciona información para solucionar problemas que te ayudará a identificar y resolver problemas que podrías encontrar mientras usas la biblioteca de supervisión de TPU.
Faltan funciones o métricas
Si no puedes ver algunas funciones o métricas, la causa más común es una versión libtpu desactualizada. Las funciones y métricas de la biblioteca de supervisión de TPU se incluyen en las versiones de libtpu. Es posible que las versiones desactualizadas no incluyan funciones y métricas nuevas.
Comprueba la versión de libtpu que se ejecuta en tu entorno:
Línea de comandos:
pip show libtpu
Python:
import libtpu
print(libtpu.__version__)
Si no usas la versión más reciente de libtpu, usa el siguiente comando para actualizar la biblioteca:
pip install --upgrade libtpu