Biblioteca de supervisión de TPU

Obtén estadísticas detalladas sobre el rendimiento y el comportamiento de tu hardware de Cloud TPU con las capacidades avanzadas de supervisión de TPU, creadas directamente sobre la capa de software fundamental, LibTPU. Si bien LibTPU abarca controladores, bibliotecas de redes, el compilador XLA y el entorno de ejecución de TPU para interactuar con las TPU, este documento se centra en la biblioteca de supervisión de TPU.

La biblioteca de supervisión de TPU proporciona lo siguiente:

  • Observabilidad integral: Obtén acceso a la API de telemetría y al conjunto de métricas, que proporcionan estadísticas detalladas sobre el rendimiento operativo y los comportamientos específicos de tus TPU.

  • Kits de herramientas de diagnóstico: Proporcionan un SDK y una interfaz de línea de comandos (CLI) diseñados para permitir la depuración y el análisis de rendimiento detallado de tus recursos de TPU.

Estas funciones de supervisión están diseñadas para ser una solución de alto nivel orientada al cliente, que te proporciona las herramientas esenciales para optimizar tus cargas de trabajo de TPU de manera eficaz.

La biblioteca de supervisión de TPU te brinda información detallada sobre el rendimiento de las cargas de trabajo de aprendizaje automático en el hardware de TPU. Está diseñado para ayudarte a comprender el uso de la TPU, identificar cuellos de botella y depurar problemas de rendimiento. Te brinda información más detallada que las métricas de interrupción, las métricas de buen rendimiento y otras métricas.

Comienza a usar la biblioteca de supervisión de TPU

Acceder a estas estadísticas potentes es sencillo. La funcionalidad de supervisión de TPU está integrada en el SDK de LibTPU, por lo que se incluye cuando instalas LibTPU.

Instala LibTPU

pip install libtpu

Como alternativa, las actualizaciones de LibTPU se coordinan con los lanzamientos de JAX, lo que significa que, cuando instalas el lanzamiento de JAX más reciente (que se lanza mensualmente), por lo general, se fijará en la versión de LibTPU compatible más reciente y sus funciones.

Instala JAX

pip install -U "jax[tpu]"

Para los usuarios de PyTorch, instalar PyTorch/XLA proporciona la funcionalidad de supervisión de LibTPU y TPU más reciente.

Instala PyTorch/XLA

pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' \
  -f https://storage.googleapis.com/libtpu-releases/index.html \
  -f https://storage.googleapis.com/libtpu-wheels/index.html

  # Optional: if you're using custom kernels, install pallas dependencies
pip install 'torch_xla[pallas]' \
  -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
  -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

Para obtener más información sobre cómo instalar PyTorch/XLA, consulta Instalación en el repositorio de GitHub de PyTorch/XLA.

Importa la biblioteca en Python

Para comenzar a usar la biblioteca de supervisión de TPU, debes importar el módulo libtpu en tu código de Python.

from libtpu.sdk import tpumonitoring

Enumera todas las funciones compatibles

Enumera todos los nombres de las métricas y la funcionalidad que admiten:


from libtpu.sdk import tpumonitoring

tpumonitoring.help()
" libtpu.sdk.monitoring.help():
      List all supported functionality.

  libtpu.sdk.monitoring.list_support_metrics()
      List support metric names in the list of str format.

  libtpu.sdk.monitoring.get_metric(metric_name:str)
      Get metric data with metric name. It represents the snapshot mode.
      The metric data is a object with `description()` and `data()` methods,
      where the `description()` returns a string describe the format of data
      and data unit, `data()` returns the metric data in the list in str format.
"

Métricas admitidas

En el siguiente muestra de código, se muestra cómo enumerar todos los nombres de métricas admitidos:

from libtpu.sdk import tpumonitoring

tpumonitoring.list_supported_metrics()

["duty_cycle_pct", "tensorcore_util", "hbm_util", ...]

En la siguiente tabla, se muestran todas las métricas y sus definiciones correspondientes:

Métrica Definición Nombre de la métrica para la API Valores de ejemplo
Uso de Tensor Core Mide el porcentaje de uso de TensorCore, calculado como el porcentaje de operaciones que forman parte de las operaciones de TensorCore. Se hizo un muestreo de 10 microsegundos cada 1 segundo. No puedes modificar la tasa de muestreo. Esta métrica te permite supervisar la eficiencia de tus cargas de trabajo en dispositivos TPU. tensorcore_util ['1.11', '2.22', '3.33', '4.44']

# Porcentaje de uso para los IDs de acelerador del 0 al 3.
Porcentaje del ciclo de trabajo Porcentaje de tiempo durante el último período de muestra (cada 5 segundos; se puede ajustar configurando la marca LIBTPU_INIT_ARG) durante el cual el acelerador se procesaba de forma activa (registrado con los ciclos que se usaron para ejecutar programas HLO durante el último período de muestreo). Esta métrica representa qué tan ocupada está una TPU y se emite por chip. duty_cycle_pct ['10.00', '20.00', '30.00', '40.00']

# Porcentaje del ciclo de trabajo para los IDs de acelerador del 0 al 3.
Capacidad total de HBM Esta métrica informa la capacidad total de HBM en bytes. hbm_capacity_total ['30000000000', '30000000000', '30000000000', '30000000000']

# Capacidad total de HBM en bytes que se adjunta a los IDs de acelerador del 0 al 3.
Uso de la capacidad de HBM Esta métrica informa el uso de la capacidad de HBM en bytes durante el último período de muestreo (cada 5 segundos; se puede ajustar configurando la marca LIBTPU_INIT_ARG). hbm_capacity_usage ['100', '200', '300', '400']

# Uso de capacidad para HBM en bytes que se adjuntan a los IDs de acelerador del 0 al 3.
Latencia de transferencia de búfer Latencias de transferencia de red para el tráfico de varias porciones a gran escala. Esta visualización te permite comprender el entorno general de rendimiento de la red. buffer_transfer_latency ["'8MB+', '2233.25', '2182.02', '3761.93', '19277.01', '53553.6'"]

# Tamaño del búfer, media, p50, p90, p99 y p99.9 de la distribución de latencia de transferencia de red.
Métricas de distribución del tiempo de ejecución de operaciones de alto nivel Proporciona estadísticas detalladas del rendimiento sobre el estado de ejecución del objeto binario compilado de HLO, lo que permite la detección de regresiones y la depuración a nivel del modelo. hlo_exec_timing ["'tensorcore-0', '10.00', '10.00', '20.00', '30.00', '40.00'"]

# Es la distribución de la duración del tiempo de ejecución del HLO para CoreType-CoreID con la media, p50, p90, p95 y p999.
Tamaño de la cola del optimizador de alto nivel El monitoreo del tamaño de la cola de ejecución de HLO hace un seguimiento de la cantidad de programas HLO compilados que están en espera o en ejecución. Esta métrica revela la congestión de la canalización de ejecución, lo que permite identificar los cuellos de botella de rendimiento en la ejecución del hardware, la sobrecarga del controlador o la asignación de recursos. hlo_queue_size ["tensorcore-0: 1", "tensorcore-1: 2"]

# Mide el tamaño de la cola para CoreType-CoreID.
Latencia colectiva de extremo a extremo Esta métrica mide la latencia colectiva de extremo a extremo en la DCN en microsegundos, desde el host que inicia la operación hasta todos los pares que reciben el resultado. Incluye la reducción de datos del host y el envío de resultados a la TPU. Los resultados son cadenas que detallan el tamaño del búfer, el tipo y las latencias medias, p50, p90, p95 y p99.9. collective_e2e_latency ["8MB+-ALL_REDUCE, 1000, 2000, 3000, 4000, 5000", …]

# Transfer size-collective op, mean, p50, p90, p95, p999 of collective end to end latency.
Latencia de ida y vuelta en la capa de transporte Distribución de los tiempos de ida y vuelta (RTT) mínimos observados en las conexiones TCP que usa gRPC para el tráfico de TPU de múltiples segmentos. grpc_tcp_min_round_trip_times ['27.63, 29.03, 38.52, 41.63, 52.74']

# Representa la media, el percentil 50, el percentil 90, el percentil 95 y el percentil 99.9 de la distribución en microsegundos (µs).
Capacidad de procesamiento en la capa de transporte Es la distribución acumulativa de la capacidad de procesamiento reciente de las conexiones TCP que usa gRPC para el tráfico de TPU de varias porciones. grpc_tcp_delivery_rates ['11354.89, 10986.35, 18239.55, 25718.35, 44841.55']

# Representa la media, el percentil 50, el percentil 90, el percentil 95 y el percentil 99.9 de la distribución en microsegundos (µs).

Lee datos de métricas

Para leer los datos de métricas, especifica el nombre de la métrica cuando llames a la función tpumonitoring.get_metric. Puedes insertar verificaciones de métricas ad hoc en el código de bajo rendimiento para identificar si los problemas de rendimiento provienen del software o del hardware.

En el siguiente muestra de código, se muestra cómo leer la métrica duty_cycle:

from libtpu.sdk import tpumonitoring

metric = tpumonitoring.get_metric("duty_cycle_pct")

metric.description()
"The metric provides a list of duty cycle percentages, one for each
accelerator (from accelerator_0 to accelerator_x). The duty cycle represents
the percentage of time an accelerator was actively processing during the
last sample period, indicating TPU utilization."

metric.data()
["0.00", "0.00", "0.00", "0.00"]

# accelerator_0-3

Usa métricas para verificar la utilización de la TPU

En los siguientes ejemplos, se muestra cómo usar las métricas de la biblioteca de supervisión de TPU para hacer un seguimiento del uso de la TPU.

Supervisa el ciclo de trabajo de la TPU durante el entrenamiento de JAX

Situación: Estás ejecutando una secuencia de comandos de entrenamiento de JAX y deseas supervisar la métrica duty_cycle_pct de la TPU durante todo el proceso de entrenamiento para confirmar que las TPU se utilizan de manera eficaz. Puedes registrar esta métrica periódicamente durante el entrenamiento para hacer un seguimiento de la utilización de la TPU.

En el siguiente muestra de código, se muestra cómo supervisar el ciclo de trabajo de la TPU durante el entrenamiento de JAX:

import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
import time

 # --- Your JAX model and training setup would go here ---
 #  --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
    return jnp.sum(x)

def loss_fn(params, x, y):
    preds = simple_model(x)
    return jnp.mean((preds - y)**2)

def train_step(params, x, y, optimizer):
    grads = jax.grad(loss_fn)(params, x, y)
    return optimizer.update(grads, params)

key = jax.random.PRNGKey(0)
params = jnp.array([1.0, 2.0]) # Example params
optimizer = ... # Your optimizer (for example, optax.adam)
data_x = jnp.ones((10, 10))
data_y = jnp.zeros((10,))

num_epochs = 10
log_interval_steps = 2  # Log duty cycle every 2 steps

for epoch in range(num_epochs):
    for step in range(5): # Example steps per epoch

        params = train_step(params, data_x, data_y, optimizer)

        if (step + 1) % log_interval_steps == 0:
            # --- Integrate TPU Monitoring Library here to get duty_cycle ---
            duty_cycle_metric = tpumonitoring.get_metric("duty_cycle_pct")
            duty_cycle_data = duty_cycle_metric.data
            print(f"Epoch {epoch+1}, Step {step+1}: TPU Duty Cycle Data:")
            print(f"  Description: {duty_cycle_metric.description}")
            print(f"  Data: {duty_cycle_data}")
            # --- End TPU Monitoring Library Integration ---

        # --- Rest of your training loop logic ---
        time.sleep(0.1) # Simulate some computation

print("Training complete.")

Verifica la utilización de la HBM antes de ejecutar la inferencia de JAX

Situación: Antes de ejecutar la inferencia con tu modelo de JAX, verifica el uso actual de la HBM (memoria de gran ancho de banda) en la TPU para confirmar que tienes suficiente memoria disponible y obtener una medición de referencia antes de que comience la inferencia.

# The following code sample shows how to check HBM utilization before JAX inference:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring

  # --- Your JAX model and inference setup would go here ---
  # --- Example placeholder model (replace with your actual model loading/setup)---
def simple_model(x):
    return jnp.sum(x)

key = jax.random.PRNGKey(0)
params = ... # Load your trained parameters

  # Integrate the TPU Monitoring Library to get HBM utilization before inference
hbm_util_metric = tpumonitoring.get_metric("hbm_util")
hbm_util_data = hbm_util_metric.data
print("HBM Utilization Before Inference:")
print(f"  Description: {hbm_util_metric.description}")
print(f"  Data: {hbm_util_data}")
  # End TPU Monitoring Library Integration

  # Your Inference Logic
input_data = jnp.ones((1, 10)) # Example input
predictions = simple_model(input_data)
print("Inference Predictions:", predictions)

print("Inference complete.")

Cómo consultar las métricas de red

Situación: Estás ejecutando una carga de trabajo multihost y de Multislice, y quieres conectarte a uno de los Pods o las TPU de GKE con SSH para ver las métricas de red mientras se ejecuta la carga de trabajo. Los comandos también se pueden incorporar directamente en la carga de trabajo de varios hosts.

import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring

 # --- Your JAX model and training setup goes here ---
 #  --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
    return jnp.sum(x)
     # --- details here ---

    # ==============================================================================
    # Metric 1: TCP Delivery Rate
    # ==============================================================================
    # This metric reflects the delivery rate of the TCP connection (bytes delivered / elapsed time).
    # The output is a list of strings representing latency statistics:
    # [mean, p50, p90, p95, p99.9]
    # Example: ['100.00', '200.00', '300.00', '400.00', '500.00'] (Values in Mbps)

        # Get the metric object
        delivery_rate_metric = tpumonitoring.get_metric(metric_name_rate)

        # Print the description provided by the library
        print("Description:", delivery_rate_metric.description())

        # Print the actual data payload
        print("Data:", delivery_rate_metric.data())

    # ==============================================================================
    # Metric 2: TCP Minimum Round Trip Time (RTT)
    # ==============================================================================
    # This metric reflects the minimum RTT measured between sending a TCP packet
    # and receiving the acknowledgement.
    # The output is a list of strings representing latency statistics:
    # [mean, p50, p90, p95, p99.9]
    # Example: ['100.00', '200.00', '300.00', '400.00', '500.00'] (Values in us - microseconds)

        # Get the metric object
        min_rtt_metric = tpumonitoring.get_metric(metric_name_rtt)

        # Print the description provided by the library
        print("Description:", min_rtt_metric.description())

        # Print the actual data payload
        print("Data:", min_rtt_metric.data())

Frecuencia de actualización de las métricas de TPU

La frecuencia de actualización de las métricas de TPU se limita a un mínimo de un segundo. Los datos de las métricas del host se exportan con una frecuencia fija de 1 Hz. La latencia que introduce este proceso de exportación es insignificante. Las métricas de tiempo de ejecución de LibTPU no están sujetas a la misma restricción de frecuencia. Sin embargo, para mantener la coherencia, estas métricas también se muestrean a 1 Hz o 1 muestra por segundo.

Módulo de TPU-Z

TPU-Z es una herramienta de telemetría y depuración para las TPU. Proporciona información detallada sobre el estado del tiempo de ejecución de todos los núcleos de TPU conectados a un host. La funcionalidad se proporciona a través del módulo tpuz, que forma parte del módulo libtpu.sdk en el SDK de libtpu de Python. El módulo proporciona una instantánea del estado de cada núcleo.

El caso de uso principal de TPU-Z es diagnosticar bloqueos o interbloqueos en cargas de trabajo de TPU distribuidas. Puedes consultar el servicio TPU-Z en los hosts para capturar el estado de cada núcleo, comparar los contadores de programa, las ubicaciones de HLO y los IDs de ejecución en todos los núcleos para identificar anomalías.

Usa la función get_core_state_summary() dentro de la biblioteca libtpu.sdk para mostrar las métricas de tpu-z:

summary = sdk.tpuz.get_core_state_summary()

El resultado de las métricas de TPU-Z se proporciona como un diccionario. A continuación, se muestra un ejemplo truncado para un solo núcleo:

{
  "host_name": "my-tpu-host-vm",
  "core_states": {
    "1": {
      "core_id": {
        "global_core_id": 1,
        "chip_id": 0,
        "core_on_chip": {
          "type": "TPU_CORE_TYPE_TENSOR_CORE",
          "index": 1
        }
      },
      "sequencer_info": [
        {
          "sequencer_type": "TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER",
          "sequencer_index": 0,
          "pc": 4490,
          "program_id": 3274167277388825310,
          "run_id": 3
        }
      ],
      "program_fingerprint": "b'\\xbefB\\xc6\\x1eb\\xc1#\\xd0...'",
      "queued_program_info": [],
      "error_message": ""
    }
    // ...
  }
}

Para recuperar información sobre los optimizadores de alto nivel (HLO) en cada núcleo, establece el parámetro include_hlo_info en True:

summary = sdk.tpuz.get_core_state_summary(include_hlo_info=True)

El resultado incluye información adicional sobre HLO:

"1": {
    "core_id": {
      "global_core_id": 1,
      "chip_id": 0,
      "core_on_chip": {
        "type": "TPU_CORE_TYPE_TENSOR_CORE",
        "index": 1
      }
    },
    "sequencer_info": [
      {
        "sequencer_type": "TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER",
        "sequencer_index": 0,
        "pc": 17776,
        "tag": 3,
        "tracemark": 2147483646,
        "program_id": 3230481660274331500,
        "run_id": 81,
        "hlo_location": "HLO: fusion.11; HLO computation: main.126_spmd",
        "hlo_detailed_info": "[{\"details\":\"HloModule fusion.11, entry_computation_layout={(bf16[20>..."
      }
    ],
    "program_fingerprint": "\ufffdU\ufffd4j\u7c6e\ufffd\ufffd{\u0017\ufffd\ufffdHHV\ufffdD\ufffde\uff>",
    "launch_id": 1394130914,
    "queued_program_info": [
      {
        "run_id": 81,
        "launch_id": 1394130914,
        "program_fingerprint": "\ufffdU\ufffd4j\u7c6e\ufffd\ufffd{\u0017\ufffd\ufffdHHV\ufffdD\ufffde>"
      }
    ]
  }

Métricas de tpu-z

La función get_core_state_summary devuelve métricas de TPU-Z en forma de un diccionario con la siguiente estructura.

CurrentCoreStateSummary

El diccionario CurrentCoreStateSummary proporciona un resumen detallado del estado de un núcleo de TPU individual.

Campo Tipo Descripción
core_id dictionary Es un diccionario TpuCoreIdentifier que contiene información del ID sobre el núcleo de la TPU.
sequencer_info Lista de diccionarios Es una lista de diccionarios SequencerInfo que describen el estado de cada secuenciador en el núcleo.
program_fingerprint bytes Huella digital del programa que se ejecuta en este núcleo.
launch_id integer Es el ID de lanzamiento del programa actual o más reciente.
queued_program_info Lista de diccionarios Es una lista de diccionarios QueuedProgramInfo para los programas en cola de ejecución.
error_message cadena Son los mensajes de error de este núcleo.

TpuCoreIdentifier

El diccionario TpuCoreIdentifier proporciona información de ID para los núcleos dentro del sistema de TPU.

Campo Tipo Descripción
global_core_id integer Es el ID del núcleo.
chip_id integer Es el ID del chip al que pertenece el núcleo.
core_on_chip dictionary Es un diccionario TpuCoreOnChip que describe el tipo del núcleo y su índice en el chip.

TpuCoreOnChip

El diccionario TpuCoreOnChip contiene información sobre las propiedades de un núcleo dentro de un chip específico.

Campo Tipo Descripción
type cadena Es el tipo de núcleo de TPU. Por ejemplo: TPU_CORE_TYPE_TENSOR_CORE.
index integer Índice del núcleo en el chip.

SequencerInfo

El diccionario SequencerInfo contiene información sobre el estado de un solo secuenciador en un núcleo.

Campo Tipo Descripción
sequencer_type cadena Es el tipo de secuenciador. Por ejemplo: TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER.
sequencer_index integer Índice del secuenciador (si hay varios del mismo tipo).
pc integer Es el valor actual del contador de programa.
program_id integer Es el ID asociado a una instancia específica de un programa que se inicia para su ejecución en un núcleo de TPU.
run_id integer Es el ID de ejecución asociado a una instancia específica de la ejecución de un programa en un núcleo de TPU.
hlo_location cadena Es la información de ubicación del optimizador de alto nivel.
hlo_detailed_info cadena Es información detallada sobre el optimizador de alto nivel.

QueuedProgramInfo

El diccionario QueuedProgramInfo contiene información sobre los programas en cola para su ejecución en un núcleo.

Campo Tipo Descripción
run_id integer Es el ID de ejecución del programa en cola.
launch_id integer Es el ID de lanzamiento del programa en la cola.
program_fingerprint bytes Es la huella dactilar del programa en cola.

TPU-Z con JAX

Puedes acceder a las métricas de tpu-z en las cargas de trabajo de JAX a través de la biblioteca libtpu.sdk. La siguiente secuencia de comandos de Python usa JAX para el procesamiento de tensores de alto rendimiento y, al mismo tiempo, usa el SDK de libtpu en un subproceso en segundo plano para supervisar el estado y la actividad del hardware de TPU subyacente.

Incluye los siguientes paquetes de Python:

import jax
import jax.numpy as jnp
import time
import threading
from functools import partial
from libtpu import sdk

La función monitor_tpu_status usa un subproceso en segundo plano para mostrar de forma continua el estado operativo de los núcleos de las TPU mientras la aplicación principal ejecuta una carga de trabajo de JAX. Actúa como una herramienta de diagnóstico en tiempo real.

def monitor_tpu_status():
  """Monitors TPU status in a background thread."""

  while monitoring_active:
    try:
      summary = sdk.tpuz.get_core_state_summary(include_hlo_info=True)
      if summary and 'core_states' in summary:
        print(summary)
      else:
        print('WARNING: Call returned an empty or invalid summary.')
    except RuntimeError as e:
      print(f'FAIL: Error calling API: {e}')
    except Exception as e:
      print(f'FAIL: Unexpected error in monitor thread: {e}')

    for _ in range(MONITORING_INTERVAL_SECONDS * 2):
      if not monitoring_active:
        break
      time.sleep(0.5)
  print('✅ Monitoring thread stopped.')

La función transformer_block implementa una capa completa de la arquitectura Transformer, que es el componente fundamental de los LLM.

@partial(jax.jit, static_argnames=['num_heads'])
def transformer_block(params, x, num_heads=32):
  """A simplified but computationally intensive Transformer block."""
  # Multi-head Self-Attention
  qkv = jnp.dot(x, params['qkv_kernel'])
  q, k, v = jnp.array_split(qkv, 3, axis=-1)

  # Reshape for multi-head attention
  q = q.reshape(q.shape[0], q.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
  k = k.reshape(k.shape[0], k.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
  v = v.reshape(v.shape[0], v.shape[1], num_heads, -1).transpose(0, 2, 1, 3)

  # Scaled dot-product attention
  attention_scores = jnp.einsum('nhqd,nhkd->nhqk', q, k) / jnp.sqrt(q.shape[-1])
  attention_weights = jax.nn.softmax(attention_scores, axis=-1)
  attention_output = jnp.einsum('nhqk,nhvd->nhqd', attention_weights, v)
  attention_output = attention_output.transpose(0, 2, 1, 3).reshape(x.shape)

  attention_output = jnp.dot(attention_output, params['o_kernel'])

  # Residual connection and Layer Normalization 1
  h1 = x + attention_output
  h1_norm = h1 - jnp.mean(h1, axis=-1, keepdims=True)
  h1_norm = h1_norm / jnp.sqrt(
      jnp.mean(jnp.square(h1_norm), axis=-1, keepdims=True) + 1e-5
  )

  # Feed-Forward Network
  ffn_hidden = jax.nn.gelu(jnp.dot(h1_norm, params['ffn1_kernel']))
  ffn_output = jnp.dot(ffn_hidden, params['ffn2_kernel'])

  # Residual connection and Layer Normalization 2
  h2 = h1_norm + ffn_output
  h2_norm = h2 - jnp.mean(h2, axis=-1, keepdims=True)
  h2_norm = h2_norm / jnp.sqrt(
      jnp.mean(jnp.square(h2_norm), axis=-1, keepdims=True) + 1e-5
  )

  return h2_norm

La función main coordina la configuración del cálculo de JAX, inicia la supervisión en segundo plano de la TPU y ejecuta el bucle principal de la carga de trabajo.

def main():
  num_devices = jax.device_count()
  print(f"Running on {num_devices} devices.")

  batch_size = 128 * num_devices
  seq_len = 512
  embed_dim = 1024
  ffn_dim = embed_dim * 4

  key = jax.random.PRNGKey(0)

  params = {
      'qkv_kernel': jax.random.normal(
          key, (embed_dim, embed_dim * 3), dtype=jnp.bfloat16
      ),
      'o_kernel': jax.random.normal(
          key, (embed_dim, embed_dim), dtype=jnp.bfloat16
      ),
      'ffn1_kernel': jax.random.normal(
          key, (embed_dim, ffn_dim), dtype=jnp.bfloat16
      ),
      'ffn2_kernel': jax.random.normal(
          key, (ffn_dim, embed_dim), dtype=jnp.bfloat16
      ),
  }
  input_data = jax.random.normal(
      key, (batch_size, seq_len, embed_dim), dtype=jnp.bfloat16
  )
  input_data = jax.device_put(input_data)
  monitor_thread = threading.Thread(target=monitor_tpu_status)
  monitor_thread.start()
  print("Starting JAX computation loop...")
  start_time = time.time()
  iterations = 0
  while time.time() - start_time < JOB_DURATION_SECONDS:
    result = transformer_block(params, input_data)
    result.block_until_ready()
    iterations += 1
    print(f'  -> Jax iteration {iterations} complete.', end='\r')

  print(f"\nCompleted {iterations} iterations in {JOB_DURATION_SECONDS} seconds.")

  global monitoring_active
  monitoring_active = False
  monitor_thread.join()

if __name__ == '__main__':
  main()

Soluciona problemas

En esta sección, se proporciona información para solucionar problemas que te ayudará a identificar y resolver problemas que podrías encontrar mientras usas la biblioteca de Monitoring de TPU.

Faltan funciones o métricas

Si no puedes ver algunas funciones o métricas, la causa más común es una versión libtpu desactualizada. Las funciones y métricas de la biblioteca de supervisión de TPU se incluyen en las versiones de libtpu, y es posible que las versiones desactualizadas no incluyan funciones y métricas nuevas.

Verifica la versión de libtpu que se ejecuta en tu entorno:

Línea de comandos:

pip show libtpu

Python:

import libtpu

print(libtpu.__version__)

Si no usas la versión más reciente de libtpu, usa el siguiente comando para actualizar la biblioteca:

pip install --upgrade libtpu