Bibliothèque de surveillance des TPU

Obtenez des insights approfondis sur les performances et le comportement de votre matériel Cloud TPU grâce à des fonctionnalités de surveillance avancées, directement intégrées à la couche logicielle de base, LibTPU. Bien que LibTPU englobe les pilotes, les bibliothèques réseau, le compilateur XLA et l'environnement d'exécution TPU pour l'interaction avec les TPU, ce document ne traite que de la bibliothèque de surveillance des TPU.

La bibliothèque de surveillance des TPU fournit les éléments suivants :

  • Observabilité complète : accédez à l'API de télémétrie et à la suite de métriques. Cela vous permet d'obtenir des insights détaillés sur les performances opérationnelles et les comportements spécifiques de vos TPU.

  • Kits d'outils de diagnostic : un SDK et une interface de ligne de commande (CLI) conçus pour permettre le débogage et l'analyse approfondie des performances de vos ressources TPU.

Ces fonctionnalités de surveillance sont conçues pour offrir une solution de premier niveau aux clients. Elles fournissent les outils essentiels pour optimiser vos charges de travail TPU.

La bibliothèque de surveillance des TPU vous fournit des informations détaillées sur les performances des charges de travail de machine learning sur le matériel TPU. Elle est conçue pour vous aider à comprendre votre utilisation des TPU, à identifier les goulots d'étranglement et à déboguer les problèmes de performances. Elle fournit des informations plus détaillées que les métriques d'interruption, de débit utile et autres.

Premiers pas avec la bibliothèque de surveillance des TPU

Il est facile d'accéder à ces précieux insights. La fonctionnalité de surveillance des TPU est intégrée au SDK LibTPU. Elle est donc incluse lorsque vous installez LibTPU.

Installer LibTPU

pip install libtpu

Les mises à jour de LibTPU sont coordonnées avec les versions de JAX. Cela signifie que lorsque vous installez la dernière version de JAX (publiée chaque mois), vous êtes généralement redirigé vers la dernière version compatible de LibTPU afin de bénéficier de ses fonctionnalités.

Installer JAX

pip install -U "jax[tpu]"

Pour les utilisateurs de PyTorch, l'installation de PyTorch/XLA permet d'obtenir les dernières fonctionnalités de LibTPU et de surveillance des TPU.

Installer PyTorch/XLA

pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' \
  -f https://storage.googleapis.com/libtpu-releases/index.html \
  -f https://storage.googleapis.com/libtpu-wheels/index.html

  # Optional: if you're using custom kernels, install pallas dependencies
pip install 'torch_xla[pallas]' \
  -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
  -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

Pour en savoir plus sur l'installation de PyTorch/XLA, consultez la section Installation dans le dépôt GitHub de PyTorch/XLA.

Importer la bibliothèque dans le code Python

Pour commencer à utiliser la bibliothèque de surveillance des TPU, vous devez importer le module libtpu dans votre code Python.

from libtpu.sdk import tpumonitoring

Lister toutes les fonctionnalités disponibles

Pour lister tous les noms de métriques et les fonctionnalités qu'elles prennent en charge :


from libtpu.sdk import tpumonitoring

tpumonitoring.help()
" libtpu.sdk.monitoring.help():
      List all supported functionality.

  libtpu.sdk.monitoring.list_support_metrics()
      List support metric names in the list of str format.

  libtpu.sdk.monitoring.get_metric(metric_name:str)
      Get metric data with metric name. It represents the snapshot mode.
      The metric data is a object with `description()` and `data()` methods,
      where the `description()` returns a string describe the format of data
      and data unit, `data()` returns the metric data in the list in str format.
"

Métriques acceptées

L'exemple de code suivant montre comment recenser tous les noms de métriques acceptés :

from libtpu.sdk import tpumonitoring

tpumonitoring.list_supported_metrics()

["duty_cycle_pct", "tensorcore_util", "hbm_util", ...]

Le tableau suivant présente toutes les métriques et les définitions correspondantes :

Métrique Définition Nom de la métrique pour l'API Exemples de valeurs
Utilisation des Tensor Cores Mesure le pourcentage d'utilisation de votre TensorCore (dérivé du pourcentage des opérations faisant partie des opérations TensorCore). Cette valeur est échantillonnée toutes les secondes pendant 10 microsecondes. Vous ne pouvez pas modifier le taux d'échantillonnage. Cette métrique vous permet de surveiller l'efficacité de vos charges de travail sur les appareils TPU. tensorcore_util ['1.11', '2.22', '3.33', '4.44']

# pourcentage d'utilisation pour les ID d'accélérateur 0 à 3
Pourcentage du cycle d'utilisation Pourcentage de temps au cours de la dernière période d'échantillonnage (toutes les cinq secondes ; peut être ajusté en définissant l'option LIBTPU_INIT_ARG) pendant lequel l'accélérateur a été en mode de traitement actif (enregistré avec les cycles utilisés pour exécuter les programmes HLO au cours de la dernière période d'échantillonnage). Cette métrique représente la charge d'un TPU. Elle est émise par puce. duty_cycle_pct ['10.00', '20.00', '30.00', '40.00']

# pourcentage de cycle d'utilisation pour les ID d'accélérateur 0 à 3
Capacité totale de la mémoire HBM Cette métrique indique la capacité totale de la mémoire HBM en octets. hbm_capacity_total ['30000000000', '30000000000', '30000000000', '30000000000']

# Capacité totale de mémoire HBM (en octets) associée aux ID d'accélérateur 0 à 3
Utilisation de la capacité de mémoire HBM Cette métrique indique l'utilisation de la capacité de mémoire HBM (en octets) au cours de la période d'échantillonnage précédente (toutes les cinq secondes ; peut être ajustée en définissant l'option LIBTPU_INIT_ARG). hbm_capacity_usage ['100', '200', '300', '400']

# Utilisation de la capacité de mémoire HBM (en octets) associée aux ID d'accélérateur 0 à 3
Latence de transfert de la mémoire tampon Latences de transfert réseau pour le trafic multitranche à très grande échelle. Cette visualisation vous permet de comprendre l'environnement global des performances réseau. buffer_transfer_latency ["'8MB+', '2233.25', '2182.02', '3761.93', '19277.01', '53553.6'"]

# taille du tampon, moyenne, centiles (p50, p90, p99 et p99.9) de la distribution de la latence du transfert réseau
Métriques de haut niveau pour la distribution du temps d'exécution des opérations Ces métriques fournissent des insights précis sur les performances de l'état d'exécution du binaire compilé HLO, ce qui permet de détecter les régressions et de déboguer au niveau du modèle. hlo_exec_timing ["'tensorcore-0', '10.00', '10.00', '20.00', '30.00', '40.00'"]

# Distribution de la durée d'exécution HLO pour CoreType-CoreID avec moyenne, centiles (p50, p90, p99, p99.9)
Taille de la file d'attente de l'optimiseur de haut niveau La surveillance de la taille de la file d'exécution HLO permet de suivre le nombre de programmes HLO compilés en attente ou en cours d'exécution. Cette métrique révèle toute congestion du pipeline d'exécution, ce qui permet d'identifier les goulots d'étranglement entraînant une baisse de performances dans l'exécution matérielle, une éventuelle surcharge du pilote ou un problème d'allocation de ressources. hlo_queue_size ["tensorcore-0: 1", "tensorcore-1: 2"]

# Mesure la taille de la file d'attente pour CoreType-CoreID.
Latence collective de bout en bout Cette métrique mesure la latence collective de bout en bout sur le DCN en microsecondes, du lancement de l'opération par l'hôte jusqu'à la réception du résultat par tous les pairs. Cela inclut la réduction des données côté hôte et l'envoi de la sortie au TPU. Les résultats sont des chaînes qui détaillent la taille du tampon, le type et les latences moyennes (p50, P90, p95 et p99.9). collective_e2e_latency ["8MB+-ALL_REDUCE, 1000, 2000, 3000, 4000, 5000", …]

# Taille du transfert-opération collective, moyenne, centiles (p50, p90, p95, p99.9) de la latence collective de bout en bout

Lire les données de métrique : mode instantané

Pour activer le mode instantané, spécifiez le nom de la métrique lorsque vous appelez la fonction tpumonitoring.get_metric. Le mode instantané vous permet d'insérer des vérifications de métriques ad hoc dans un code présentant de faibles performances pour déterminer si les problèmes de performances proviennent du logiciel ou du matériel.

L'exemple de code suivant montre comment utiliser le mode instantané pour lire duty_cycle.

from libtpu.sdk import tpumonitoring

metric = tpumonitoring.get_metric("duty_cycle_pct")

metric.description()
"The metric provides a list of duty cycle percentages, one for each
accelerator (from accelerator_0 to accelerator_x). The duty cycle represents
the percentage of time an accelerator was actively processing during the
last sample period, indicating TPU utilization."

metric.data()
["0.00", "0.00", "0.00", "0.00"]

# accelerator_0-3

Accéder aux métriques à l'aide de la CLI

La procédure suivante montre comment interagir avec les métriques LibTPU à l'aide de la CLI :

  1. Installez tpu-info :

    pip install tpu-info
    
    
    # Access help information of tpu-info
    tpu-info --help / -h
    
    
  2. Exécutez la vision par défaut de tpu-info :

    tpu-info
    

    Le résultat ressemble à ce qui suit :

   TPU Chips
   ┏━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━┓
    Chip         Type         Devices  PID       ┡━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━┩
    /dev/accel0  TPU v4 chip  1        130007     /dev/accel1  TPU v4 chip  1        130007     /dev/accel2  TPU v4 chip  1        130007     /dev/accel3  TPU v4 chip  1        130007    └─────────────┴─────────────┴─────────┴────────┘

   TPU Runtime Utilization
   ┏━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
    Device  Memory usage          Duty cycle    ┡━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
    0       0.00 GiB / 31.75 GiB       0.00%     1       0.00 GiB / 31.75 GiB       0.00%     2       0.00 GiB / 31.75 GiB       0.00%     3       0.00 GiB / 31.75 GiB       0.00%    └────────┴──────────────────────┴────────────┘

   TensorCore Utilization
   ┏━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓
    Chip ID  TensorCore Utilization    ┡━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩
    0                         0.00%     1                         0.00%     3                         0.00%     2                         0.00% |
   └─────────┴────────────────────────┘

   Buffer Transfer Latency
   ┏━━━━━━━━━━━━━┳━━━━━┳━━━━━┳━━━━━┳━━━━━━┓
    Buffer Size  P50  P90  P95  P999    ┡━━━━━━━━━━━━━╇━━━━━╇━━━━━╇━━━━━╇━━━━━━┩
          8MB+  | 0us  0us  0us   0us |
   └─────────────┴─────┴─────┴─────┴──────┘

Utiliser des métriques pour vérifier l'utilisation des TPU

Les exemples suivants montrent comment utiliser les métriques de la bibliothèque de surveillance des TPU pour suivre l'utilisation des TPU.

Surveiller le cycle d'utilisation des TPU pendant l'entraînement JAX

Scénario : Vous exécutez un script d'entraînement JAX et souhaitez surveiller la métrique duty_cycle_pct des TPU tout au long du processus d'entraînement pour confirmer que ceux-ci sont utilisés efficacement. Vous pouvez consigner cette métrique périodiquement pendant l'entraînement pour suivre l'utilisation des TPU.

L'exemple de code suivant montre comment surveiller le cycle d'utilisation des TPU pendant l'entraînement JAX :

import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
import time

 # --- Your JAX model and training setup would go here ---
 #  --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
    return jnp.sum(x)

def loss_fn(params, x, y):
    preds = simple_model(x)
    return jnp.mean((preds - y)**2)

def train_step(params, x, y, optimizer):
    grads = jax.grad(loss_fn)(params, x, y)
    return optimizer.update(grads, params)

key = jax.random.PRNGKey(0)
params = jnp.array([1.0, 2.0]) # Example params
optimizer = ... # Your optimizer (for example, optax.adam)
data_x = jnp.ones((10, 10))
data_y = jnp.zeros((10,))

num_epochs = 10
log_interval_steps = 2  # Log duty cycle every 2 steps

for epoch in range(num_epochs):
    for step in range(5): # Example steps per epoch

        params = train_step(params, data_x, data_y, optimizer)

        if (step + 1) % log_interval_steps == 0:
            # --- Integrate TPU Monitoring Library here to get duty_cycle ---
            duty_cycle_metric = tpumonitoring.get_metric("duty_cycle_pct")
            duty_cycle_data = duty_cycle_metric.data
            print(f"Epoch {epoch+1}, Step {step+1}: TPU Duty Cycle Data:")
            print(f"  Description: {duty_cycle_metric.description}")
            print(f"  Data: {duty_cycle_data}")
            # --- End TPU Monitoring Library Integration ---

        # --- Rest of your training loop logic ---
        time.sleep(0.1) # Simulate some computation

print("Training complete.")

Vérifier l'utilisation de la mémoire HBM avant d'exécuter l'inférence JAX

Scénario : Avant d'exécuter l'inférence avec votre modèle JAX, vérifiez l'utilisation actuelle de la mémoire HBM (High Bandwidth Memory) sur le TPU. Cela permet de vous assurer que vous disposez de suffisamment de mémoire et d'obtenir une mesure de référence avant de lancer l'inférence.

# The following code sample shows how to check HBM utilization before JAX inference:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring

  # --- Your JAX model and inference setup would go here ---
  # --- Example placeholder model (replace with your actual model loading/setup)---
def simple_model(x):
    return jnp.sum(x)

key = jax.random.PRNGKey(0)
params = ... # Load your trained parameters

  # Integrate the TPU Monitoring Library to get HBM utilization before inference
hbm_util_metric = tpumonitoring.get_metric("hbm_util")
hbm_util_data = hbm_util_metric.data
print("HBM Utilization Before Inference:")
print(f"  Description: {hbm_util_metric.description}")
print(f"  Data: {hbm_util_data}")
  # End TPU Monitoring Library Integration

  # Your Inference Logic
input_data = jnp.ones((1, 10)) # Example input
predictions = simple_model(input_data)
print("Inference Predictions:", predictions)

print("Inference complete.")

Fréquence d'exportation des métriques de TPU

La fréquence d'actualisation des métriques de TPU ne peut pas être inférieure à une seconde. Les données de métriques d'hôte sont exportées à une fréquence fixe de 1 Hz. La latence introduite par ce processus d'exportation est négligeable. Les métriques d'exécution de LibTPU ne sont pas soumises à la même contrainte de fréquence. Toutefois, par souci de cohérence, ces métriques sont également échantillonnées à 1 Hz, soit un échantillon par seconde.