Bibliothèque de surveillance des TPU
Obtenez des insights approfondis sur les performances et le comportement de votre matériel Cloud TPU grâce à des fonctionnalités de surveillance avancées, directement intégrées à la couche logicielle de base, LibTPU. Bien que LibTPU englobe les pilotes, les bibliothèques réseau, le compilateur XLA et l'environnement d'exécution TPU pour l'interaction avec les TPU, ce document ne traite que de la bibliothèque de surveillance des TPU.
La bibliothèque de surveillance des TPU fournit les éléments suivants :
Observabilité complète : accédez à l'API de télémétrie et à la suite de métriques. Cela vous permet d'obtenir des insights détaillés sur les performances opérationnelles et les comportements spécifiques de vos TPU.
Kits d'outils de diagnostic : un SDK et une interface de ligne de commande (CLI) conçus pour permettre le débogage et l'analyse approfondie des performances de vos ressources TPU.
Ces fonctionnalités de surveillance sont conçues pour offrir une solution de premier niveau aux clients. Elles fournissent les outils essentiels pour optimiser vos charges de travail TPU.
La bibliothèque de surveillance des TPU vous fournit des informations détaillées sur les performances des charges de travail de machine learning sur le matériel TPU. Elle est conçue pour vous aider à comprendre votre utilisation des TPU, à identifier les goulots d'étranglement et à déboguer les problèmes de performances. Elle fournit des informations plus détaillées que les métriques d'interruption, de débit utile et autres.
Premiers pas avec la bibliothèque de surveillance des TPU
Il est facile d'accéder à ces précieux insights. La fonctionnalité de surveillance des TPU est intégrée au SDK LibTPU. Elle est donc incluse lorsque vous installez LibTPU.
Installer LibTPU
pip install libtpu
Les mises à jour de LibTPU sont coordonnées avec les versions de JAX. Cela signifie que lorsque vous installez la dernière version de JAX (publiée chaque mois), vous êtes généralement redirigé vers la dernière version compatible de LibTPU afin de bénéficier de ses fonctionnalités.
Installer JAX
pip install -U "jax[tpu]"
Pour les utilisateurs de PyTorch, l'installation de PyTorch/XLA permet d'obtenir les dernières fonctionnalités de LibTPU et de surveillance des TPU.
Installer PyTorch/XLA
pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' \
-f https://storage.googleapis.com/libtpu-releases/index.html \
-f https://storage.googleapis.com/libtpu-wheels/index.html
# Optional: if you're using custom kernels, install pallas dependencies
pip install 'torch_xla[pallas]' \
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
Pour en savoir plus sur l'installation de PyTorch/XLA, consultez la section Installation dans le dépôt GitHub de PyTorch/XLA.
Importer la bibliothèque dans le code Python
Pour commencer à utiliser la bibliothèque de surveillance des TPU, vous devez importer le module libtpu
dans votre code Python.
from libtpu.sdk import tpumonitoring
Lister toutes les fonctionnalités disponibles
Pour lister tous les noms de métriques et les fonctionnalités qu'elles prennent en charge :
from libtpu.sdk import tpumonitoring
tpumonitoring.help()
" libtpu.sdk.monitoring.help():
List all supported functionality.
libtpu.sdk.monitoring.list_support_metrics()
List support metric names in the list of str format.
libtpu.sdk.monitoring.get_metric(metric_name:str)
Get metric data with metric name. It represents the snapshot mode.
The metric data is a object with `description()` and `data()` methods,
where the `description()` returns a string describe the format of data
and data unit, `data()` returns the metric data in the list in str format.
"
Métriques acceptées
L'exemple de code suivant montre comment recenser tous les noms de métriques acceptés :
from libtpu.sdk import tpumonitoring
tpumonitoring.list_supported_metrics()
["duty_cycle_pct", "tensorcore_util", "hbm_util", ...]
Le tableau suivant présente toutes les métriques et les définitions correspondantes :
Métrique | Définition | Nom de la métrique pour l'API | Exemples de valeurs |
---|---|---|---|
Utilisation des Tensor Cores | Mesure le pourcentage d'utilisation de votre TensorCore (dérivé du pourcentage des opérations faisant partie des opérations TensorCore). Cette valeur est échantillonnée toutes les secondes pendant 10 microsecondes. Vous ne pouvez pas modifier le taux d'échantillonnage. Cette métrique vous permet de surveiller l'efficacité de vos charges de travail sur les appareils TPU. |
tensorcore_util
|
['1.11', '2.22', '3.33', '4.44'] # pourcentage d'utilisation pour les ID d'accélérateur 0 à 3 |
Pourcentage du cycle d'utilisation | Pourcentage de temps au cours de la dernière période d'échantillonnage (toutes les cinq secondes ; peut être ajusté en définissant l'option LIBTPU_INIT_ARG ) pendant lequel l'accélérateur a été en mode de traitement actif (enregistré avec les cycles utilisés pour exécuter les programmes HLO au cours de la dernière période d'échantillonnage). Cette métrique représente la charge d'un TPU. Elle est émise par puce.
|
duty_cycle_pct
|
['10.00', '20.00', '30.00', '40.00'] # pourcentage de cycle d'utilisation pour les ID d'accélérateur 0 à 3 |
Capacité totale de la mémoire HBM | Cette métrique indique la capacité totale de la mémoire HBM en octets. |
hbm_capacity_total
|
['30000000000', '30000000000', '30000000000', '30000000000'] # Capacité totale de mémoire HBM (en octets) associée aux ID d'accélérateur 0 à 3 |
Utilisation de la capacité de mémoire HBM | Cette métrique indique l'utilisation de la capacité de mémoire HBM (en octets) au cours de la période d'échantillonnage précédente (toutes les cinq secondes ; peut être ajustée en définissant l'option LIBTPU_INIT_ARG ).
|
hbm_capacity_usage
|
['100', '200', '300', '400'] # Utilisation de la capacité de mémoire HBM (en octets) associée aux ID d'accélérateur 0 à 3 |
Latence de transfert de la mémoire tampon | Latences de transfert réseau pour le trafic multitranche à très grande échelle. Cette visualisation vous permet de comprendre l'environnement global des performances réseau. |
buffer_transfer_latency
|
["'8MB+', '2233.25', '2182.02', '3761.93', '19277.01', '53553.6'"] # taille du tampon, moyenne, centiles (p50, p90, p99 et p99.9) de la distribution de la latence du transfert réseau |
Métriques de haut niveau pour la distribution du temps d'exécution des opérations | Ces métriques fournissent des insights précis sur les performances de l'état d'exécution du binaire compilé HLO, ce qui permet de détecter les régressions et de déboguer au niveau du modèle. |
hlo_exec_timing
|
["'tensorcore-0', '10.00', '10.00', '20.00', '30.00', '40.00'"] # Distribution de la durée d'exécution HLO pour CoreType-CoreID avec moyenne, centiles (p50, p90, p99, p99.9) |
Taille de la file d'attente de l'optimiseur de haut niveau | La surveillance de la taille de la file d'exécution HLO permet de suivre le nombre de programmes HLO compilés en attente ou en cours d'exécution. Cette métrique révèle toute congestion du pipeline d'exécution, ce qui permet d'identifier les goulots d'étranglement entraînant une baisse de performances dans l'exécution matérielle, une éventuelle surcharge du pilote ou un problème d'allocation de ressources. |
hlo_queue_size
|
["tensorcore-0: 1", "tensorcore-1: 2"] # Mesure la taille de la file d'attente pour CoreType-CoreID. |
Latence collective de bout en bout | Cette métrique mesure la latence collective de bout en bout sur le DCN en microsecondes, du lancement de l'opération par l'hôte jusqu'à la réception du résultat par tous les pairs. Cela inclut la réduction des données côté hôte et l'envoi de la sortie au TPU. Les résultats sont des chaînes qui détaillent la taille du tampon, le type et les latences moyennes (p50, P90, p95 et p99.9). |
collective_e2e_latency
|
["8MB+-ALL_REDUCE, 1000, 2000, 3000, 4000, 5000", …] # Taille du transfert-opération collective, moyenne, centiles (p50, p90, p95, p99.9) de la latence collective de bout en bout |
Lire les données de métrique : mode instantané
Pour activer le mode instantané, spécifiez le nom de la métrique lorsque vous appelez la fonction tpumonitoring.get_metric
. Le mode instantané vous permet d'insérer des vérifications de métriques ad hoc dans un code présentant de faibles performances pour déterminer si les problèmes de performances proviennent du logiciel ou du matériel.
L'exemple de code suivant montre comment utiliser le mode instantané pour lire duty_cycle
.
from libtpu.sdk import tpumonitoring
metric = tpumonitoring.get_metric("duty_cycle_pct")
metric.description()
"The metric provides a list of duty cycle percentages, one for each
accelerator (from accelerator_0 to accelerator_x). The duty cycle represents
the percentage of time an accelerator was actively processing during the
last sample period, indicating TPU utilization."
metric.data()
["0.00", "0.00", "0.00", "0.00"]
# accelerator_0-3
Accéder aux métriques à l'aide de la CLI
La procédure suivante montre comment interagir avec les métriques LibTPU à l'aide de la CLI :
Installez
tpu-info
:pip install tpu-info
# Access help information of tpu-info tpu-info --help / -h
Exécutez la vision par défaut de
tpu-info
:tpu-info
Le résultat ressemble à ce qui suit :
TPU Chips
┏━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━┓
┃ Chip ┃ Type ┃ Devices ┃ PID ┃
┡━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━┩
│ /dev/accel0 │ TPU v4 chip │ 1 │ 130007 │
│ /dev/accel1 │ TPU v4 chip │ 1 │ 130007 │
│ /dev/accel2 │ TPU v4 chip │ 1 │ 130007 │
│ /dev/accel3 │ TPU v4 chip │ 1 │ 130007 │
└─────────────┴─────────────┴─────────┴────────┘
TPU Runtime Utilization
┏━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Device ┃ Memory usage ┃ Duty cycle ┃
┡━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ 0 │ 0.00 GiB / 31.75 GiB │ 0.00% │
│ 1 │ 0.00 GiB / 31.75 GiB │ 0.00% │
│ 2 │ 0.00 GiB / 31.75 GiB │ 0.00% │
│ 3 │ 0.00 GiB / 31.75 GiB │ 0.00% │
└────────┴──────────────────────┴────────────┘
TensorCore Utilization
┏━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Chip ID ┃ TensorCore Utilization ┃
┡━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩
│ 0 │ 0.00% │
│ 1 │ 0.00% │
│ 3 │ 0.00% │
│ 2 │ 0.00% |
└─────────┴────────────────────────┘
Buffer Transfer Latency
┏━━━━━━━━━━━━━┳━━━━━┳━━━━━┳━━━━━┳━━━━━━┓
┃ Buffer Size ┃ P50 ┃ P90 ┃ P95 ┃ P999 ┃
┡━━━━━━━━━━━━━╇━━━━━╇━━━━━╇━━━━━╇━━━━━━┩
│ 8MB+ | 0us │ 0us │ 0us │ 0us |
└─────────────┴─────┴─────┴─────┴──────┘
Utiliser des métriques pour vérifier l'utilisation des TPU
Les exemples suivants montrent comment utiliser les métriques de la bibliothèque de surveillance des TPU pour suivre l'utilisation des TPU.
Surveiller le cycle d'utilisation des TPU pendant l'entraînement JAX
Scénario : Vous exécutez un script d'entraînement JAX et souhaitez surveiller la métrique duty_cycle_pct
des TPU tout au long du processus d'entraînement pour confirmer que ceux-ci sont utilisés efficacement. Vous pouvez consigner cette métrique périodiquement pendant l'entraînement pour suivre l'utilisation des TPU.
L'exemple de code suivant montre comment surveiller le cycle d'utilisation des TPU pendant l'entraînement JAX :
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
import time
# --- Your JAX model and training setup would go here ---
# --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
return jnp.sum(x)
def loss_fn(params, x, y):
preds = simple_model(x)
return jnp.mean((preds - y)**2)
def train_step(params, x, y, optimizer):
grads = jax.grad(loss_fn)(params, x, y)
return optimizer.update(grads, params)
key = jax.random.PRNGKey(0)
params = jnp.array([1.0, 2.0]) # Example params
optimizer = ... # Your optimizer (for example, optax.adam)
data_x = jnp.ones((10, 10))
data_y = jnp.zeros((10,))
num_epochs = 10
log_interval_steps = 2 # Log duty cycle every 2 steps
for epoch in range(num_epochs):
for step in range(5): # Example steps per epoch
params = train_step(params, data_x, data_y, optimizer)
if (step + 1) % log_interval_steps == 0:
# --- Integrate TPU Monitoring Library here to get duty_cycle ---
duty_cycle_metric = tpumonitoring.get_metric("duty_cycle_pct")
duty_cycle_data = duty_cycle_metric.data
print(f"Epoch {epoch+1}, Step {step+1}: TPU Duty Cycle Data:")
print(f" Description: {duty_cycle_metric.description}")
print(f" Data: {duty_cycle_data}")
# --- End TPU Monitoring Library Integration ---
# --- Rest of your training loop logic ---
time.sleep(0.1) # Simulate some computation
print("Training complete.")
Vérifier l'utilisation de la mémoire HBM avant d'exécuter l'inférence JAX
Scénario : Avant d'exécuter l'inférence avec votre modèle JAX, vérifiez l'utilisation actuelle de la mémoire HBM (High Bandwidth Memory) sur le TPU. Cela permet de vous assurer que vous disposez de suffisamment de mémoire et d'obtenir une mesure de référence avant de lancer l'inférence.
# The following code sample shows how to check HBM utilization before JAX inference:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
# --- Your JAX model and inference setup would go here ---
# --- Example placeholder model (replace with your actual model loading/setup)---
def simple_model(x):
return jnp.sum(x)
key = jax.random.PRNGKey(0)
params = ... # Load your trained parameters
# Integrate the TPU Monitoring Library to get HBM utilization before inference
hbm_util_metric = tpumonitoring.get_metric("hbm_util")
hbm_util_data = hbm_util_metric.data
print("HBM Utilization Before Inference:")
print(f" Description: {hbm_util_metric.description}")
print(f" Data: {hbm_util_data}")
# End TPU Monitoring Library Integration
# Your Inference Logic
input_data = jnp.ones((1, 10)) # Example input
predictions = simple_model(input_data)
print("Inference Predictions:", predictions)
print("Inference complete.")
Fréquence d'exportation des métriques de TPU
La fréquence d'actualisation des métriques de TPU ne peut pas être inférieure à une seconde. Les données de métriques d'hôte sont exportées à une fréquence fixe de 1 Hz. La latence introduite par ce processus d'exportation est négligeable. Les métriques d'exécution de LibTPU ne sont pas soumises à la même contrainte de fréquence. Toutefois, par souci de cohérence, ces métriques sont également échantillonnées à 1 Hz, soit un échantillon par seconde.