ספריית מעקב אחרי TPU
בעזרת יכולות מתקדמות של ניטור TPU, שמבוססות ישירות על שכבת התוכנה הבסיסית LibTPU, תוכלו לקבל תובנות מעמיקות לגבי הביצועים וההתנהגות של חומרת Cloud TPU. ספריית LibTPU כוללת מנהלי התקנים, ספריות רשת, קומפיילר XLA וזמן ריצה של TPU לצורך אינטראקציה עם יחידות TPU. עם זאת, המסמך הזה מתמקד בספריית המעקב של TPU.
ספריית המעקב של TPU מספקת:
יכולת תצפית מקיפה: גישה ל-API של טלמטריה ולחבילת מדדים, שמספקים תובנות מפורטות לגבי הביצועים התפעוליים וההתנהגויות הספציפיות של יחידות ה-TPU.
ערכות כלים לאבחון: ערכות הכלים האלה כוללות SDK וממשק שורת פקודה (CLI) שנועדו לאפשר ניפוי באגים וניתוח מעמיק של הביצועים של משאבי ה-TPU.
תכונות המעקב האלה נועדו להיות פתרון ברמה גבוהה שפונה ללקוחות, ומספקות לכם את הכלים החיוניים לאופטימיזציה יעילה של עומסי העבודה של TPU.
ספריית המעקב של TPU מספקת מידע מפורט על הביצועים של עומסי עבודה של למידת מכונה בחומרת TPU. הוא נועד לעזור לכם להבין את השימוש ב-TPU, לזהות צווארי בקבוק ולפתור בעיות בביצועים. הוא מספק מידע מפורט יותר ממדדי ההפרעות, ממדדי התפוקה וממדדים אחרים.
תחילת העבודה עם ספריית המעקב אחר TPU
הגישה לתובנות החשובות האלה היא פשוטה. הפונקציונליות של מעקב אחרי TPU משולבת ב-LibTPU SDK, ולכן היא כלולה כשמתקינים את LibTPU.
התקנת LibTPU
pip install libtpu
לחלופין, העדכונים של LibTPU מתואמים עם הגרסאות של JAX, כלומר כשמתקינים את הגרסה העדכנית של JAX (שמתפרסמת מדי חודש), בדרך כלל היא תהיה מקושרת לגרסה התואמת העדכנית של LibTPU ולתכונות שלה.
התקנה של JAX
pip install -U "jax[tpu]"
משתמשי PyTorch יכולים להתקין את PyTorch/XLA כדי לקבל את הגרסה העדכנית של LibTPU ואת הפונקציונליות של ניטור TPU.
התקנה של PyTorch/XLA
pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' \
-f https://storage.googleapis.com/libtpu-releases/index.html \
-f https://storage.googleapis.com/libtpu-wheels/index.html
# Optional: if you're using custom kernels, install pallas dependencies
pip install 'torch_xla[pallas]' \
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
מידע נוסף על התקנה של PyTorch/XLA זמין במאגר GitHub של PyTorch/XLA בקטע Installation.
ייבוא הספרייה ב-Python
כדי להתחיל להשתמש בספריית המעקב של TPU, צריך לייבא את מודול libtpu
לקוד Python.
from libtpu.sdk import tpumonitoring
הצגת רשימה של כל הפונקציות הנתמכות
רשימה של כל שמות המדדים והפונקציונליות שהם תומכים בה:
from libtpu.sdk import tpumonitoring
tpumonitoring.help()
" libtpu.sdk.monitoring.help():
List all supported functionality.
libtpu.sdk.monitoring.list_support_metrics()
List support metric names in the list of str format.
libtpu.sdk.monitoring.get_metric(metric_name:str)
Get metric data with metric name. It represents the snapshot mode.
The metric data is a object with `description()` and `data()` methods,
where the `description()` returns a string describe the format of data
and data unit, `data()` returns the metric data in the list in str format.
"
מדדים נתמכים
בדוגמת הקוד הבאה אפשר לראות איך מציגים רשימה של כל שמות המדדים הנתמכים:
from libtpu.sdk import tpumonitoring
tpumonitoring.list_supported_metrics()
["duty_cycle_pct", "tensorcore_util", "hbm_util", ...]
בטבלה הבאה מוצגים כל המדדים וההגדרות שלהם:
| מדד | הגדרה | שם המדד ל-API | ערכים לדוגמה |
|---|---|---|---|
| Tensor Core Utilization | המדד הזה מודד את אחוז השימוש ב-TensorCore, ומחושב כאחוז הפעולות שמהוות חלק מהפעולות של TensorCore. הדגימה מתבצעת כל שנייה, למשך 10 מיקרו-שניות. אי אפשר לשנות את קצב הדגימה. המדד הזה מאפשר לעקוב אחרי היעילות של עומסי העבודה במכשירי TPU. |
tensorcore_util
|
['1.11', '2.22', '3.33', '4.44']
# utilization percentage for accelerator ID 0-3. |
| אחוז הדיוטי סייקל | אחוז הזמן במהלך תקופת הדגימה האחרונה (כל 5 שניות; אפשר לשנות את ההגדרה באמצעות האפשרות LIBTPU_INIT_ARG) שבו המאיץ עיבד באופן פעיל (נרשם עם מחזורי העיבוד ששימשו להפעלת תוכניות HLO במהלך תקופת הדגימה האחרונה). המדד הזה
מייצג את רמת העומס של TPU. המדד מופק לכל שבב.
|
duty_cycle_pct
|
['10.00', '20.00', '30.00', '40.00']
# Duty cycle percentage for accelerator ID 0-3. |
| קיבולת HBM כוללת | המדד הזה מציג את קיבולת ה-HBM הכוללת בבייט. |
hbm_capacity_total
|
['30000000000', '30000000000', '30000000000', '30000000000']
# Total HBM capacity in bytes that attached to accelerator ID 0-3. |
| HBM Capacity Usage | המדד הזה מציג את השימוש בקיבולת HBM בבייט במהלך תקופת הדגימה האחרונה (כל 5 שניות, אפשר לשנות את זה באמצעות ההגדרה של הדגל LIBTPU_INIT_ARG).
|
hbm_capacity_usage
|
['100', '200', '300', '400']
# Capacity usage for HBM in bytes that attached to accelerator ID 0-3. |
| זמן האחזור של העברת מאגר הנתונים הזמני | זמני האחזור של העברת נתונים ברשת לתעבורת נתונים של כמה פרוסות בקנה מידה גדול. ההדמיה הזו מאפשרת לכם להבין את סביבת הביצועים הכוללת של הרשת. |
buffer_transfer_latency
|
["'8MB+', '2233.25', '2182.02', '3761.93', '19277.01', '53553.6'"]
# buffer size, mean, p50, p90, p99, p99.9 of network transfer latency distribution. |
| מדדים של התפלגות זמן הביצוע של פעולות ברמה גבוהה | מספק תובנות מפורטות לגבי הביצועים של סטטוס ההפעלה של קובץ בינארי שעבר קומפילציה של HLO, ומאפשר זיהוי רגרסיה וניפוי באגים ברמת המודל. |
hlo_exec_timing
|
["'tensorcore-0', '10.00', '10.00', '20.00', '30.00', '40.00'"]
# The HLO execution time duration distribution for CoreType-CoreID with mean, p50, p90, p95, p999. |
| גודל התור של כלי האופטימיזציה ברמה גבוהה | במסגרת המעקב אחרי גודל תור הביצוע של HLO, המערכת עוקבת אחרי מספר תוכניות HLO שעברו קומפילציה וממתינות לביצוע או נמצאות בתהליך ביצוע. המדד הזה חושף עומס בצינור הביצוע, ומאפשר לזהות צווארי בקבוק בביצועים בביצוע חומרה, בעומס יתר של מנהלי התקנים או בהקצאת משאבים. |
hlo_queue_size
|
["tensorcore-0: 1", "tensorcore-1: 2"]
# Measures queue size for CoreType-CoreID. |
| זמן האחזור הכולל מקצה לקצה | המדד הזה מודד את זמן האחזור הכולל מקצה לקצה ברשת DCN במיקרו-שניות, מהמארח שמתחיל את הפעולה ועד שכל העמיתים מקבלים את הפלט. הוא כולל צמצום נתונים בצד המארח ושליחת פלט ל-TPU. התוצאות הן מחרוזות שמפרטות את שטח האחסון הזמני, הסוג והממוצע, וזמני הטעינה P50, P90, P95 ו-P99.9. |
collective_e2e_latency
|
["8MB+-ALL_REDUCE, 1000, 2000, 3000, 4000, 5000", …]
# Transfer size-collective op, mean, p50, p90, p95, p999 of collective end to end latency. |
| זמן האחזור של הלוך ושוב בשכבת התעבורה | התפלגות של זמני הלוך ושוב (RTT) מינימליים שנצפו בחיבורי TCP שמשמשים את gRPC לתנועת TPU מרובת פרוסות. |
grpc_tcp_min_round_trip_times
|
['27.63, 29.03, 38.52, 41.63, 52.74']
# Represents the distribution's mean, p50, p90, p95, and p99.9 percentiles in microseconds (µs). |
| Throughput at Transport Layer | התפלגות מצטברת של קצב העברת הנתונים האחרון של חיבורי TCP שנעשה בהם שימוש ב-gRPC לתנועת נתונים של TPU עם כמה פרוסות. |
grpc_tcp_delivery_rates
|
['11354.89, 10986.35, 18239.55, 25718.35, 44841.55']
# Represents the distribution's mean, p50, p90, p95, and p99.9 percentiles in microseconds (µs). |
קריאת נתוני מדדים
כדי לקרוא נתוני מדדים, מציינים את שם המדד כשקוראים לפונקציה tpumonitoring.get_metric. אפשר להוסיף בדיקות אד-הוק של מדדים לקוד עם ביצועים נמוכים כדי לזהות אם בעיות בביצועים נובעות מתוכנה או מחומרה.
בדוגמת הקוד הבאה אפשר לראות איך קוראים את מדד duty_cycle:
from libtpu.sdk import tpumonitoring
metric = tpumonitoring.get_metric("duty_cycle_pct")
metric.description()
"The metric provides a list of duty cycle percentages, one for each
accelerator (from accelerator_0 to accelerator_x). The duty cycle represents
the percentage of time an accelerator was actively processing during the
last sample period, indicating TPU utilization."
metric.data()
["0.00", "0.00", "0.00", "0.00"]
# accelerator_0-3
שימוש במדדים כדי לבדוק את ניצול ה-TPU
בדוגמאות הבאות מוסבר איך להשתמש במדדים מ-TPU Monitoring Library כדי לעקוב אחרי השימוש ב-TPU.
מעקב אחר מחזור הפעילות של TPU במהלך אימון ב-JAX
תרחיש: אתם מריצים סקריפט לאימון JAX ורוצים לעקוב אחרי מדד ה-TPU duty_cycle_pct לאורך תהליך האימון כדי לוודא שהשימוש ב-TPU יעיל. אפשר לרשום את המדד הזה ביומן באופן תקופתי במהלך האימון כדי לעקוב אחרי ניצול ה-TPU.
בדוגמת הקוד הבאה אפשר לראות איך עוקבים אחרי מחזור הפעילות של TPU במהלך אימון ב-JAX:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
import time
# --- Your JAX model and training setup would go here ---
# --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
return jnp.sum(x)
def loss_fn(params, x, y):
preds = simple_model(x)
return jnp.mean((preds - y)**2)
def train_step(params, x, y, optimizer):
grads = jax.grad(loss_fn)(params, x, y)
return optimizer.update(grads, params)
key = jax.random.PRNGKey(0)
params = jnp.array([1.0, 2.0]) # Example params
optimizer = ... # Your optimizer (for example, optax.adam)
data_x = jnp.ones((10, 10))
data_y = jnp.zeros((10,))
num_epochs = 10
log_interval_steps = 2 # Log duty cycle every 2 steps
for epoch in range(num_epochs):
for step in range(5): # Example steps per epoch
params = train_step(params, data_x, data_y, optimizer)
if (step + 1) % log_interval_steps == 0:
# --- Integrate TPU Monitoring Library here to get duty_cycle ---
duty_cycle_metric = tpumonitoring.get_metric("duty_cycle_pct")
duty_cycle_data = duty_cycle_metric.data
print(f"Epoch {epoch+1}, Step {step+1}: TPU Duty Cycle Data:")
print(f" Description: {duty_cycle_metric.description}")
print(f" Data: {duty_cycle_data}")
# --- End TPU Monitoring Library Integration ---
# --- Rest of your training loop logic ---
time.sleep(0.1) # Simulate some computation
print("Training complete.")
בדיקת השימוש ב-HBM לפני הפעלת מסקנות JAX
תרחיש: לפני שמריצים הסקה עם מודל JAX, בודקים את השימוש הנוכחי ב-HBM (זיכרון עם רוחב פס גבוה) ב-TPU כדי לוודא שיש מספיק זיכרון זמין וכדי לקבל מדידת בסיס לפני שההסקה מתחילה.
# The following code sample shows how to check HBM utilization before JAX inference:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
# --- Your JAX model and inference setup would go here ---
# --- Example placeholder model (replace with your actual model loading/setup)---
def simple_model(x):
return jnp.sum(x)
key = jax.random.PRNGKey(0)
params = ... # Load your trained parameters
# Integrate the TPU Monitoring Library to get HBM utilization before inference
hbm_util_metric = tpumonitoring.get_metric("hbm_util")
hbm_util_data = hbm_util_metric.data
print("HBM Utilization Before Inference:")
print(f" Description: {hbm_util_metric.description}")
print(f" Data: {hbm_util_data}")
# End TPU Monitoring Library Integration
# Your Inference Logic
input_data = jnp.ones((1, 10)) # Example input
predictions = simple_model(input_data)
print("Inference Predictions:", predictions)
print("Inference complete.")
בדיקת מדדי הרשת
תרחיש: אתם מריצים עומס עבודה מרובה מארחים ומרובה פרוסות, ואתם רוצים להתחבר לאחד מ-pods של GKE או ל-TPU באמצעות SSH כדי לראות את מדדי הרשת בזמן שעומס העבודה פועל. אפשר גם לשלב את הפקודות ישירות בעומס העבודה של כמה מארחים.
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
# --- Your JAX model and training setup goes here ---
# --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
return jnp.sum(x)
# --- details here ---
# ==============================================================================
# Metric 1: TCP Delivery Rate
# ==============================================================================
# This metric reflects the delivery rate of the TCP connection (bytes delivered / elapsed time).
# The output is a list of strings representing latency statistics:
# [mean, p50, p90, p95, p99.9]
# Example: ['100.00', '200.00', '300.00', '400.00', '500.00'] (Values in Mbps)
# Get the metric object
delivery_rate_metric = tpumonitoring.get_metric(metric_name_rate)
# Print the description provided by the library
print("Description:", delivery_rate_metric.description())
# Print the actual data payload
print("Data:", delivery_rate_metric.data())
# ==============================================================================
# Metric 2: TCP Minimum Round Trip Time (RTT)
# ==============================================================================
# This metric reflects the minimum RTT measured between sending a TCP packet
# and receiving the acknowledgement.
# The output is a list of strings representing latency statistics:
# [mean, p50, p90, p95, p99.9]
# Example: ['100.00', '200.00', '300.00', '400.00', '500.00'] (Values in us - microseconds)
# Get the metric object
min_rtt_metric = tpumonitoring.get_metric(metric_name_rtt)
# Print the description provided by the library
print("Description:", min_rtt_metric.description())
# Print the actual data payload
print("Data:", min_rtt_metric.data())
תדירות הרענון של מדדי TPU
תדירות הרענון של מדדי TPU מוגבלת למינימום של שנייה אחת. נתוני המדדים של המארח מיוצאים בתדירות קבועה של 1 הרץ. זמן האחזור שנוצר בתהליך הייצוא הזה הוא זניח. מדדי זמן ריצה מ-LibTPU לא כפופים לאותה מגבלת תדירות. עם זאת, כדי לשמור על עקביות, המדדים האלה נדגמים גם הם בתדירות של 1 הרץ או דגימה אחת לשנייה.
מודול TPU-Z
TPU-Z הוא כלי לטלמטריה ולניפוי באגים ב-TPU. הוא מספק מידע מפורט על סטטוס זמן הריצה של כל ליבות ה-TPU שמצורפות למארח. הפונקציונליות מסופקת דרך המודול tpuz, שהוא חלק מהמודול libtpu.sdk ב-libtpu Python SDK. המודול מספק תמונת מצב של הסטטוס של כל ליבה.
תרחיש השימוש העיקרי ב-TPU-Z הוא אבחון של תקיעות או חסימות בנקודות קריטיות בעומסי עבודה מבוזרים של TPU. אתם יכולים להריץ שאילתות בשירות TPU-Z במארחים כדי לתעד את המצב של כל ליבה, להשוות בין מוני התוכניות, מיקומי ה-HLO ומזהי ההרצה בכל הליבות כדי לזהות אנומליות.
כדי להציג את המדדים של TPU-Z, משתמשים בפונקציה get_core_state_summary() בספרייה libtpu.sdk:
summary = sdk.tpuz.get_core_state_summary()
הפלט של מדדי TPU-Z מסופק כמילון. זו דוגמה קטומה לליבה יחידה:
{
"host_name": "my-tpu-host-vm",
"core_states": {
"1": {
"core_id": {
"global_core_id": 1,
"chip_id": 0,
"core_on_chip": {
"type": "TPU_CORE_TYPE_TENSOR_CORE",
"index": 1
}
},
"sequencer_info": [
{
"sequencer_type": "TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER",
"sequencer_index": 0,
"pc": 4490,
"program_id": 3274167277388825310,
"run_id": 3
}
],
"program_fingerprint": "b'\\xbefB\\xc6\\x1eb\\xc1#\\xd0...'",
"queued_program_info": [],
"error_message": ""
}
// ...
}
}
כדי לאחזר מידע על High-Level Optimizers (HLO) בכל ליבה, צריך להגדיר את הפרמטר include_hlo_info לערך True:
summary = sdk.tpuz.get_core_state_summary(include_hlo_info=True)
הפלט כולל מידע נוסף על HLO:
"1": {
"core_id": {
"global_core_id": 1,
"chip_id": 0,
"core_on_chip": {
"type": "TPU_CORE_TYPE_TENSOR_CORE",
"index": 1
}
},
"sequencer_info": [
{
"sequencer_type": "TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER",
"sequencer_index": 0,
"pc": 17776,
"tag": 3,
"tracemark": 2147483646,
"program_id": 3230481660274331500,
"run_id": 81,
"hlo_location": "HLO: fusion.11; HLO computation: main.126_spmd",
"hlo_detailed_info": "[{\"details\":\"HloModule fusion.11, entry_computation_layout={(bf16[20>..."
}
],
"program_fingerprint": "\ufffdU\ufffd4j\u7c6e\ufffd\ufffd{\u0017\ufffd\ufffdHHV\ufffdD\ufffde\uff>",
"launch_id": 1394130914,
"queued_program_info": [
{
"run_id": 81,
"launch_id": 1394130914,
"program_fingerprint": "\ufffdU\ufffd4j\u7c6e\ufffd\ufffd{\u0017\ufffd\ufffdHHV\ufffdD\ufffde>"
}
]
}
מדדים של TPU-Z
הפונקציה get_core_state_summary מחזירה מדדים של TPU-Z בצורה של מילון עם המבנה הבא.
CurrentCoreStateSummary
CurrentCoreStateSummary המילון מספק סיכום מפורט של מצב ליבת TPU ספציפית.
| שדה | סוג | תיאור |
|---|---|---|
core_id |
מילון | מילון TpuCoreIdentifier שמכיל מידע על מזהה ליבת ה-TPU. |
sequencer_info |
רשימת המילונים | רשימה של SequencerInfo מילונים, שמתארים את המצב של כל רכיב ליצירת רצפים בליבה. |
program_fingerprint |
בייטים | טביעת האצבע של התוכנית שמופעלת בליבה הזו. |
launch_id |
מספר שלם | מזהה ההפעלה של התוכנית הנוכחית או האחרונה. |
queued_program_info |
רשימת המילונים | רשימה של QueuedProgramInfo מילונים לתוכניות שנמצאות בתור להרצה. |
error_message |
מחרוזת | הודעות שגיאה לגבי הליבה הזו. |
TpuCoreIdentifier
המילון TpuCoreIdentifier מספק מידע על מזהים של ליבות במערכת TPU.
| שדה | סוג | תיאור |
|---|---|---|
global_core_id |
מספר שלם | המזהה של הליבה. |
chip_id |
מספר שלם | המזהה של הצ'יפ שאליו שייכת הליבה. |
core_on_chip |
מילון | מילון TpuCoreOnChip שמתאר את הסוג של הליבה ואת האינדקס שלה בשבב. |
TpuCoreOnChip
מילון TpuCoreOnChip מכיל מידע על המאפיינים של ליבה מסוימת בתוך צ'יפ מסוים.
| שדה | סוג | תיאור |
|---|---|---|
type |
מחרוזת | סוג ליבת ה-TPU. לדוגמה: TPU_CORE_TYPE_TENSOR_CORE. |
index |
מספר שלם | האינדקס של הליבה בשבב. |
SequencerInfo
המילון SequencerInfo מכיל מידע על המצב של רכיב sequencer יחיד בליבה.
| שדה | סוג | תיאור |
|---|---|---|
sequencer_type |
מחרוזת | סוג הרצף. לדוגמה: TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER. |
sequencer_index |
מספר שלם | האינדקס של הרכיב ליצירת רצף (אם יש כמה רכיבים מאותו סוג). |
pc |
מספר שלם | הערך הנוכחי של מונה התוכנית. |
program_id |
מספר שלם | המזהה שמשויך למופע ספציפי של תוכנית שמופעלת לביצוע בליבת TPU. |
run_id |
מספר שלם | מזהה הריצה שמשויך למופע ספציפי של הפעלת תוכנית בליבת TPU. |
hlo_location |
מחרוזת | נתוני מיקום של כלי האופטימיזציה ברמה גבוהה. |
hlo_detailed_info |
מחרוזת | מידע מפורט על הכלי לשיפור הביצועים ברמה גבוהה. |
QueuedProgramInfo
המילון QueuedProgramInfo מכיל מידע על תוכניות שנמצאות בתור להרצה בליבה.
| שדה | סוג | תיאור |
|---|---|---|
run_id |
מספר שלם | מזהה ההפעלה של התוכנית שהוכנסה לתור. |
launch_id |
מספר שלם | מזהה ההשקה של התוכנית שנוספה לתור. |
program_fingerprint |
בייטים | טביעת האצבע של התוכנית בתור. |
TPU-Z עם JAX
אפשר לגשת למדדים של TPU-Z בעומסי עבודה של JAX באמצעות הספרייה libtpu.sdk.
סקריפט Python הבא משתמש ב-JAX לחישוב טנסורים עם ביצועים גבוהים, ובמקביל משתמש ב-libtpu SDK בשרשור ברקע כדי לעקוב אחרי המצב והפעילות של חומרת ה-TPU הבסיסית.
צריך לכלול את חבילות Python הבאות:
import jax
import jax.numpy as jnp
import time
import threading
from functools import partial
from libtpu import sdk
הפונקציה monitor_tpu_status משתמשת בשרשור ברקע כדי להציג באופן רציף את סטטוס הפעולה של ליבות ה-TPU, בזמן שהאפליקציה הראשית מבצעת עומס עבודה של JAX. הוא משמש ככלי לאבחון בעיות בזמן אמת.
def monitor_tpu_status():
"""Monitors TPU status in a background thread."""
while monitoring_active:
try:
summary = sdk.tpuz.get_core_state_summary(include_hlo_info=True)
if summary and 'core_states' in summary:
print(summary)
else:
print('WARNING: Call returned an empty or invalid summary.')
except RuntimeError as e:
print(f'FAIL: Error calling API: {e}')
except Exception as e:
print(f'FAIL: Unexpected error in monitor thread: {e}')
for _ in range(MONITORING_INTERVAL_SECONDS * 2):
if not monitoring_active:
break
time.sleep(0.5)
print('✅ Monitoring thread stopped.')
הפונקציה transformer_block מיישמת שכבה מלאה של ארכיטקטורת ה-Transformer, שהיא אבן הבניין הבסיסית של מודלים של LLM.
@partial(jax.jit, static_argnames=['num_heads'])
def transformer_block(params, x, num_heads=32):
"""A simplified but computationally intensive Transformer block."""
# Multi-head Self-Attention
qkv = jnp.dot(x, params['qkv_kernel'])
q, k, v = jnp.array_split(qkv, 3, axis=-1)
# Reshape for multi-head attention
q = q.reshape(q.shape[0], q.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
k = k.reshape(k.shape[0], k.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
v = v.reshape(v.shape[0], v.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
# Scaled dot-product attention
attention_scores = jnp.einsum('nhqd,nhkd->nhqk', q, k) / jnp.sqrt(q.shape[-1])
attention_weights = jax.nn.softmax(attention_scores, axis=-1)
attention_output = jnp.einsum('nhqk,nhvd->nhqd', attention_weights, v)
attention_output = attention_output.transpose(0, 2, 1, 3).reshape(x.shape)
attention_output = jnp.dot(attention_output, params['o_kernel'])
# Residual connection and Layer Normalization 1
h1 = x + attention_output
h1_norm = h1 - jnp.mean(h1, axis=-1, keepdims=True)
h1_norm = h1_norm / jnp.sqrt(
jnp.mean(jnp.square(h1_norm), axis=-1, keepdims=True) + 1e-5
)
# Feed-Forward Network
ffn_hidden = jax.nn.gelu(jnp.dot(h1_norm, params['ffn1_kernel']))
ffn_output = jnp.dot(ffn_hidden, params['ffn2_kernel'])
# Residual connection and Layer Normalization 2
h2 = h1_norm + ffn_output
h2_norm = h2 - jnp.mean(h2, axis=-1, keepdims=True)
h2_norm = h2_norm / jnp.sqrt(
jnp.mean(jnp.square(h2_norm), axis=-1, keepdims=True) + 1e-5
)
return h2_norm
הפונקציה main מתזמנת את ההגדרה של החישוב ב-JAX, מפעילה את המעקב ברקע של TPU ומריצה את הלולאה הראשית של עומס העבודה.
def main():
num_devices = jax.device_count()
print(f"Running on {num_devices} devices.")
batch_size = 128 * num_devices
seq_len = 512
embed_dim = 1024
ffn_dim = embed_dim * 4
key = jax.random.PRNGKey(0)
params = {
'qkv_kernel': jax.random.normal(
key, (embed_dim, embed_dim * 3), dtype=jnp.bfloat16
),
'o_kernel': jax.random.normal(
key, (embed_dim, embed_dim), dtype=jnp.bfloat16
),
'ffn1_kernel': jax.random.normal(
key, (embed_dim, ffn_dim), dtype=jnp.bfloat16
),
'ffn2_kernel': jax.random.normal(
key, (ffn_dim, embed_dim), dtype=jnp.bfloat16
),
}
input_data = jax.random.normal(
key, (batch_size, seq_len, embed_dim), dtype=jnp.bfloat16
)
input_data = jax.device_put(input_data)
monitor_thread = threading.Thread(target=monitor_tpu_status)
monitor_thread.start()
print("Starting JAX computation loop...")
start_time = time.time()
iterations = 0
while time.time() - start_time < JOB_DURATION_SECONDS:
result = transformer_block(params, input_data)
result.block_until_ready()
iterations += 1
print(f' -> Jax iteration {iterations} complete.', end='\r')
print(f"\nCompleted {iterations} iterations in {JOB_DURATION_SECONDS} seconds.")
global monitoring_active
monitoring_active = False
monitor_thread.join()
if __name__ == '__main__':
main()
פתרון בעיות
בקטע הזה מופיע מידע לפתרון בעיות שיעזור לכם לזהות ולפתור בעיות שאתם עשויים להיתקל בהן במהלך השימוש בספריית המעקב של TPU.
תכונות או מדדים חסרים
אם אתם לא מצליחים לראות חלק מהתכונות או מהמדדים, הסיבה הכי נפוצה לכך היא גרסה לא עדכנית של libtpu. התכונות והמדדים של ספריית המעקב אחרי TPU כלולים בגרסאות libtpu, ובגרסאות ישנות יותר יכול להיות שחסרים מדדים ותכונות חדשים.
בודקים את הגרסה של libtpu שפועלת בסביבה שלכם:
שורת פקודה:
pip show libtpu
Python:
import libtpu
print(libtpu.__version__)
אם אתם לא משתמשים בגרסה האחרונה של libtpu, אתם יכולים להשתמש בפקודה הבאה כדי לעדכן את הספרייה:
pip install --upgrade libtpu