ספריית מעקב אחרי TPU

בעזרת יכולות מתקדמות של ניטור TPU, שמבוססות ישירות על שכבת התוכנה הבסיסית LibTPU, תוכלו לקבל תובנות מעמיקות לגבי הביצועים וההתנהגות של חומרת Cloud TPU. ספריית LibTPU כוללת מנהלי התקנים, ספריות רשת, קומפיילר XLA וזמן ריצה של TPU לצורך אינטראקציה עם יחידות TPU. עם זאת, המסמך הזה מתמקד בספריית המעקב של TPU.

ספריית המעקב של TPU מספקת:

  • יכולת תצפית מקיפה: גישה ל-API של טלמטריה ולחבילת מדדים, שמספקים תובנות מפורטות לגבי הביצועים התפעוליים וההתנהגויות הספציפיות של יחידות ה-TPU.

  • ערכות כלים לאבחון: ערכות הכלים האלה כוללות SDK וממשק שורת פקודה (CLI) שנועדו לאפשר ניפוי באגים וניתוח מעמיק של הביצועים של משאבי ה-TPU.

תכונות המעקב האלה נועדו להיות פתרון ברמה גבוהה שפונה ללקוחות, ומספקות לכם את הכלים החיוניים לאופטימיזציה יעילה של עומסי העבודה של TPU.

ספריית המעקב של TPU מספקת מידע מפורט על הביצועים של עומסי עבודה של למידת מכונה בחומרת TPU. הוא נועד לעזור לכם להבין את השימוש ב-TPU, לזהות צווארי בקבוק ולפתור בעיות בביצועים. הוא מספק מידע מפורט יותר ממדדי ההפרעות, ממדדי התפוקה וממדדים אחרים.

תחילת העבודה עם ספריית המעקב אחר TPU

הגישה לתובנות החשובות האלה היא פשוטה. הפונקציונליות של מעקב אחרי TPU משולבת ב-LibTPU SDK, ולכן היא כלולה כשמתקינים את LibTPU.

התקנת LibTPU

pip install libtpu

לחלופין, העדכונים של LibTPU מתואמים עם הגרסאות של JAX, כלומר כשמתקינים את הגרסה העדכנית של JAX (שמתפרסמת מדי חודש), בדרך כלל היא תהיה מקושרת לגרסה התואמת העדכנית של LibTPU ולתכונות שלה.

התקנה של JAX

pip install -U "jax[tpu]"

משתמשי PyTorch יכולים להתקין את PyTorch/XLA כדי לקבל את הגרסה העדכנית של LibTPU ואת הפונקציונליות של ניטור TPU.

התקנה של PyTorch/XLA

pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' \
  -f https://storage.googleapis.com/libtpu-releases/index.html \
  -f https://storage.googleapis.com/libtpu-wheels/index.html

  # Optional: if you're using custom kernels, install pallas dependencies
pip install 'torch_xla[pallas]' \
  -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
  -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

מידע נוסף על התקנה של PyTorch/XLA זמין במאגר GitHub של PyTorch/XLA בקטע Installation.

ייבוא הספרייה ב-Python

כדי להתחיל להשתמש בספריית המעקב של TPU, צריך לייבא את מודול libtpu לקוד Python.

from libtpu.sdk import tpumonitoring

הצגת רשימה של כל הפונקציות הנתמכות

רשימה של כל שמות המדדים והפונקציונליות שהם תומכים בה:


from libtpu.sdk import tpumonitoring

tpumonitoring.help()
" libtpu.sdk.monitoring.help():
      List all supported functionality.

  libtpu.sdk.monitoring.list_support_metrics()
      List support metric names in the list of str format.

  libtpu.sdk.monitoring.get_metric(metric_name:str)
      Get metric data with metric name. It represents the snapshot mode.
      The metric data is a object with `description()` and `data()` methods,
      where the `description()` returns a string describe the format of data
      and data unit, `data()` returns the metric data in the list in str format.
"

מדדים נתמכים

בדוגמת הקוד הבאה אפשר לראות איך מציגים רשימה של כל שמות המדדים הנתמכים:

from libtpu.sdk import tpumonitoring

tpumonitoring.list_supported_metrics()

["duty_cycle_pct", "tensorcore_util", "hbm_util", ...]

בטבלה הבאה מוצגים כל המדדים וההגדרות שלהם:

מדד הגדרה שם המדד ל-API ערכים לדוגמה
Tensor Core Utilization המדד הזה מודד את אחוז השימוש ב-TensorCore, ומחושב כאחוז הפעולות שמהוות חלק מהפעולות של TensorCore. הדגימה מתבצעת כל שנייה, למשך 10 מיקרו-שניות. אי אפשר לשנות את קצב הדגימה. המדד הזה מאפשר לעקוב אחרי היעילות של עומסי העבודה במכשירי TPU. tensorcore_util ['1.11', '2.22', '3.33', '4.44']

# utilization percentage for accelerator ID 0-3.
אחוז הדיוטי סייקל אחוז הזמן במהלך תקופת הדגימה האחרונה (כל 5 שניות; אפשר לשנות את ההגדרה באמצעות האפשרות LIBTPU_INIT_ARG) שבו המאיץ עיבד באופן פעיל (נרשם עם מחזורי העיבוד ששימשו להפעלת תוכניות HLO במהלך תקופת הדגימה האחרונה). המדד הזה מייצג את רמת העומס של TPU. המדד מופק לכל שבב. duty_cycle_pct ['10.00', '20.00', '30.00', '40.00']

# Duty cycle percentage for accelerator ID 0-3.
קיבולת HBM כוללת המדד הזה מציג את קיבולת ה-HBM הכוללת בבייט. hbm_capacity_total ['30000000000', '30000000000', '30000000000', '30000000000']

# Total HBM capacity in bytes that attached to accelerator ID 0-3.
HBM Capacity Usage המדד הזה מציג את השימוש בקיבולת HBM בבייט במהלך תקופת הדגימה האחרונה (כל 5 שניות, אפשר לשנות את זה באמצעות ההגדרה של הדגל LIBTPU_INIT_ARG). hbm_capacity_usage ['100', '200', '300', '400']

# Capacity usage for HBM in bytes that attached to accelerator ID 0-3.
זמן האחזור של העברת מאגר הנתונים הזמני זמני האחזור של העברת נתונים ברשת לתעבורת נתונים של כמה פרוסות בקנה מידה גדול. ההדמיה הזו מאפשרת לכם להבין את סביבת הביצועים הכוללת של הרשת. buffer_transfer_latency ["'8MB+', '2233.25', '2182.02', '3761.93', '19277.01', '53553.6'"]

# buffer size, mean, p50, p90, p99, p99.9 of network transfer latency distribution.
מדדים של התפלגות זמן הביצוע של פעולות ברמה גבוהה מספק תובנות מפורטות לגבי הביצועים של סטטוס ההפעלה של קובץ בינארי שעבר קומפילציה של HLO, ומאפשר זיהוי רגרסיה וניפוי באגים ברמת המודל. hlo_exec_timing ["'tensorcore-0', '10.00', '10.00', '20.00', '30.00', '40.00'"]

# The HLO execution time duration distribution for CoreType-CoreID with mean, p50, p90, p95, p999.
גודל התור של כלי האופטימיזציה ברמה גבוהה במסגרת המעקב אחרי גודל תור הביצוע של HLO, המערכת עוקבת אחרי מספר תוכניות HLO שעברו קומפילציה וממתינות לביצוע או נמצאות בתהליך ביצוע. המדד הזה חושף עומס בצינור הביצוע, ומאפשר לזהות צווארי בקבוק בביצועים בביצוע חומרה, בעומס יתר של מנהלי התקנים או בהקצאת משאבים. hlo_queue_size ["tensorcore-0: 1", "tensorcore-1: 2"]

# Measures queue size for CoreType-CoreID.
זמן האחזור הכולל מקצה לקצה המדד הזה מודד את זמן האחזור הכולל מקצה לקצה ברשת DCN במיקרו-שניות, מהמארח שמתחיל את הפעולה ועד שכל העמיתים מקבלים את הפלט. הוא כולל צמצום נתונים בצד המארח ושליחת פלט ל-TPU. התוצאות הן מחרוזות שמפרטות את שטח האחסון הזמני, הסוג והממוצע, וזמני הטעינה P50,‏ P90,‏ P95 ו-P99.9. collective_e2e_latency ["8MB+-ALL_REDUCE, 1000, 2000, 3000, 4000, 5000", …]

# Transfer size-collective op, mean, p50, p90, p95, p999 of collective end to end latency.
זמן האחזור של הלוך ושוב בשכבת התעבורה התפלגות של זמני הלוך ושוב (RTT) מינימליים שנצפו בחיבורי TCP שמשמשים את gRPC לתנועת TPU מרובת פרוסות. grpc_tcp_min_round_trip_times ['27.63, 29.03, 38.52, 41.63, 52.74']

# Represents the distribution's mean, p50, p90, p95, and p99.9 percentiles in microseconds (µs).
Throughput at Transport Layer התפלגות מצטברת של קצב העברת הנתונים האחרון של חיבורי TCP שנעשה בהם שימוש ב-gRPC לתנועת נתונים של TPU עם כמה פרוסות. grpc_tcp_delivery_rates ['11354.89, 10986.35, 18239.55, 25718.35, 44841.55']

# Represents the distribution's mean, p50, p90, p95, and p99.9 percentiles in microseconds (µs).

קריאת נתוני מדדים

כדי לקרוא נתוני מדדים, מציינים את שם המדד כשקוראים לפונקציה tpumonitoring.get_metric. אפשר להוסיף בדיקות אד-הוק של מדדים לקוד עם ביצועים נמוכים כדי לזהות אם בעיות בביצועים נובעות מתוכנה או מחומרה.

בדוגמת הקוד הבאה אפשר לראות איך קוראים את מדד duty_cycle:

from libtpu.sdk import tpumonitoring

metric = tpumonitoring.get_metric("duty_cycle_pct")

metric.description()
"The metric provides a list of duty cycle percentages, one for each
accelerator (from accelerator_0 to accelerator_x). The duty cycle represents
the percentage of time an accelerator was actively processing during the
last sample period, indicating TPU utilization."

metric.data()
["0.00", "0.00", "0.00", "0.00"]

# accelerator_0-3

שימוש במדדים כדי לבדוק את ניצול ה-TPU

בדוגמאות הבאות מוסבר איך להשתמש במדדים מ-TPU Monitoring Library כדי לעקוב אחרי השימוש ב-TPU.

מעקב אחר מחזור הפעילות של TPU במהלך אימון ב-JAX

תרחיש: אתם מריצים סקריפט לאימון JAX ורוצים לעקוב אחרי מדד ה-TPU‏ duty_cycle_pct לאורך תהליך האימון כדי לוודא שהשימוש ב-TPU יעיל. אפשר לרשום את המדד הזה ביומן באופן תקופתי במהלך האימון כדי לעקוב אחרי ניצול ה-TPU.

בדוגמת הקוד הבאה אפשר לראות איך עוקבים אחרי מחזור הפעילות של TPU במהלך אימון ב-JAX:

import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
import time

 # --- Your JAX model and training setup would go here ---
 #  --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
    return jnp.sum(x)

def loss_fn(params, x, y):
    preds = simple_model(x)
    return jnp.mean((preds - y)**2)

def train_step(params, x, y, optimizer):
    grads = jax.grad(loss_fn)(params, x, y)
    return optimizer.update(grads, params)

key = jax.random.PRNGKey(0)
params = jnp.array([1.0, 2.0]) # Example params
optimizer = ... # Your optimizer (for example, optax.adam)
data_x = jnp.ones((10, 10))
data_y = jnp.zeros((10,))

num_epochs = 10
log_interval_steps = 2  # Log duty cycle every 2 steps

for epoch in range(num_epochs):
    for step in range(5): # Example steps per epoch

        params = train_step(params, data_x, data_y, optimizer)

        if (step + 1) % log_interval_steps == 0:
            # --- Integrate TPU Monitoring Library here to get duty_cycle ---
            duty_cycle_metric = tpumonitoring.get_metric("duty_cycle_pct")
            duty_cycle_data = duty_cycle_metric.data
            print(f"Epoch {epoch+1}, Step {step+1}: TPU Duty Cycle Data:")
            print(f"  Description: {duty_cycle_metric.description}")
            print(f"  Data: {duty_cycle_data}")
            # --- End TPU Monitoring Library Integration ---

        # --- Rest of your training loop logic ---
        time.sleep(0.1) # Simulate some computation

print("Training complete.")

בדיקת השימוש ב-HBM לפני הפעלת מסקנות JAX

תרחיש: לפני שמריצים הסקה עם מודל JAX, בודקים את השימוש הנוכחי ב-HBM (זיכרון עם רוחב פס גבוה) ב-TPU כדי לוודא שיש מספיק זיכרון זמין וכדי לקבל מדידת בסיס לפני שההסקה מתחילה.

# The following code sample shows how to check HBM utilization before JAX inference:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring

  # --- Your JAX model and inference setup would go here ---
  # --- Example placeholder model (replace with your actual model loading/setup)---
def simple_model(x):
    return jnp.sum(x)

key = jax.random.PRNGKey(0)
params = ... # Load your trained parameters

  # Integrate the TPU Monitoring Library to get HBM utilization before inference
hbm_util_metric = tpumonitoring.get_metric("hbm_util")
hbm_util_data = hbm_util_metric.data
print("HBM Utilization Before Inference:")
print(f"  Description: {hbm_util_metric.description}")
print(f"  Data: {hbm_util_data}")
  # End TPU Monitoring Library Integration

  # Your Inference Logic
input_data = jnp.ones((1, 10)) # Example input
predictions = simple_model(input_data)
print("Inference Predictions:", predictions)

print("Inference complete.")

בדיקת מדדי הרשת

תרחיש: אתם מריצים עומס עבודה מרובה מארחים ומרובה פרוסות, ואתם רוצים להתחבר לאחד מ-pods של GKE או ל-TPU באמצעות SSH כדי לראות את מדדי הרשת בזמן שעומס העבודה פועל. אפשר גם לשלב את הפקודות ישירות בעומס העבודה של כמה מארחים.

import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring

 # --- Your JAX model and training setup goes here ---
 #  --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
    return jnp.sum(x)
     # --- details here ---

    # ==============================================================================
    # Metric 1: TCP Delivery Rate
    # ==============================================================================
    # This metric reflects the delivery rate of the TCP connection (bytes delivered / elapsed time).
    # The output is a list of strings representing latency statistics:
    # [mean, p50, p90, p95, p99.9]
    # Example: ['100.00', '200.00', '300.00', '400.00', '500.00'] (Values in Mbps)

        # Get the metric object
        delivery_rate_metric = tpumonitoring.get_metric(metric_name_rate)

        # Print the description provided by the library
        print("Description:", delivery_rate_metric.description())

        # Print the actual data payload
        print("Data:", delivery_rate_metric.data())

    # ==============================================================================
    # Metric 2: TCP Minimum Round Trip Time (RTT)
    # ==============================================================================
    # This metric reflects the minimum RTT measured between sending a TCP packet
    # and receiving the acknowledgement.
    # The output is a list of strings representing latency statistics:
    # [mean, p50, p90, p95, p99.9]
    # Example: ['100.00', '200.00', '300.00', '400.00', '500.00'] (Values in us - microseconds)

        # Get the metric object
        min_rtt_metric = tpumonitoring.get_metric(metric_name_rtt)

        # Print the description provided by the library
        print("Description:", min_rtt_metric.description())

        # Print the actual data payload
        print("Data:", min_rtt_metric.data())

תדירות הרענון של מדדי TPU

תדירות הרענון של מדדי TPU מוגבלת למינימום של שנייה אחת. נתוני המדדים של המארח מיוצאים בתדירות קבועה של 1 הרץ. זמן האחזור שנוצר בתהליך הייצוא הזה הוא זניח. מדדי זמן ריצה מ-LibTPU לא כפופים לאותה מגבלת תדירות. עם זאת, כדי לשמור על עקביות, המדדים האלה נדגמים גם הם בתדירות של 1 הרץ או דגימה אחת לשנייה.

מודול TPU-Z

‫TPU-Z הוא כלי לטלמטריה ולניפוי באגים ב-TPU. הוא מספק מידע מפורט על סטטוס זמן הריצה של כל ליבות ה-TPU שמצורפות למארח. הפונקציונליות מסופקת דרך המודול tpuz, שהוא חלק מהמודול libtpu.sdk ב-libtpu Python SDK. המודול מספק תמונת מצב של הסטטוס של כל ליבה.

תרחיש השימוש העיקרי ב-TPU-Z הוא אבחון של תקיעות או חסימות בנקודות קריטיות בעומסי עבודה מבוזרים של TPU. אתם יכולים להריץ שאילתות בשירות TPU-Z במארחים כדי לתעד את המצב של כל ליבה, להשוות בין מוני התוכניות, מיקומי ה-HLO ומזהי ההרצה בכל הליבות כדי לזהות אנומליות.

כדי להציג את המדדים של TPU-Z, משתמשים בפונקציה get_core_state_summary() בספרייה libtpu.sdk:

summary = sdk.tpuz.get_core_state_summary()

הפלט של מדדי TPU-Z מסופק כמילון. זו דוגמה קטומה לליבה יחידה:

{
  "host_name": "my-tpu-host-vm",
  "core_states": {
    "1": {
      "core_id": {
        "global_core_id": 1,
        "chip_id": 0,
        "core_on_chip": {
          "type": "TPU_CORE_TYPE_TENSOR_CORE",
          "index": 1
        }
      },
      "sequencer_info": [
        {
          "sequencer_type": "TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER",
          "sequencer_index": 0,
          "pc": 4490,
          "program_id": 3274167277388825310,
          "run_id": 3
        }
      ],
      "program_fingerprint": "b'\\xbefB\\xc6\\x1eb\\xc1#\\xd0...'",
      "queued_program_info": [],
      "error_message": ""
    }
    // ...
  }
}

כדי לאחזר מידע על High-Level Optimizers ‏ (HLO) בכל ליבה, צריך להגדיר את הפרמטר include_hlo_info לערך True:

summary = sdk.tpuz.get_core_state_summary(include_hlo_info=True)

הפלט כולל מידע נוסף על HLO:

"1": {
    "core_id": {
      "global_core_id": 1,
      "chip_id": 0,
      "core_on_chip": {
        "type": "TPU_CORE_TYPE_TENSOR_CORE",
        "index": 1
      }
    },
    "sequencer_info": [
      {
        "sequencer_type": "TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER",
        "sequencer_index": 0,
        "pc": 17776,
        "tag": 3,
        "tracemark": 2147483646,
        "program_id": 3230481660274331500,
        "run_id": 81,
        "hlo_location": "HLO: fusion.11; HLO computation: main.126_spmd",
        "hlo_detailed_info": "[{\"details\":\"HloModule fusion.11, entry_computation_layout={(bf16[20>..."
      }
    ],
    "program_fingerprint": "\ufffdU\ufffd4j\u7c6e\ufffd\ufffd{\u0017\ufffd\ufffdHHV\ufffdD\ufffde\uff>",
    "launch_id": 1394130914,
    "queued_program_info": [
      {
        "run_id": 81,
        "launch_id": 1394130914,
        "program_fingerprint": "\ufffdU\ufffd4j\u7c6e\ufffd\ufffd{\u0017\ufffd\ufffdHHV\ufffdD\ufffde>"
      }
    ]
  }

מדדים של TPU-Z

הפונקציה get_core_state_summary מחזירה מדדים של TPU-Z בצורה של מילון עם המבנה הבא.

CurrentCoreStateSummary

CurrentCoreStateSummary המילון מספק סיכום מפורט של מצב ליבת TPU ספציפית.

שדה סוג תיאור
core_id מילון מילון TpuCoreIdentifier שמכיל מידע על מזהה ליבת ה-TPU.
sequencer_info רשימת המילונים רשימה של SequencerInfo מילונים, שמתארים את המצב של כל רכיב ליצירת רצפים בליבה.
program_fingerprint בייטים טביעת האצבע של התוכנית שמופעלת בליבה הזו.
launch_id מספר שלם מזהה ההפעלה של התוכנית הנוכחית או האחרונה.
queued_program_info רשימת המילונים רשימה של QueuedProgramInfo מילונים לתוכניות שנמצאות בתור להרצה.
error_message מחרוזת הודעות שגיאה לגבי הליבה הזו.

TpuCoreIdentifier

המילון TpuCoreIdentifier מספק מידע על מזהים של ליבות במערכת TPU.

שדה סוג תיאור
global_core_id מספר שלם המזהה של הליבה.
chip_id מספר שלם המזהה של הצ'יפ שאליו שייכת הליבה.
core_on_chip מילון מילון TpuCoreOnChip שמתאר את הסוג של הליבה ואת האינדקס שלה בשבב.

TpuCoreOnChip

מילון TpuCoreOnChip מכיל מידע על המאפיינים של ליבה מסוימת בתוך צ'יפ מסוים.

שדה סוג תיאור
type מחרוזת סוג ליבת ה-TPU. לדוגמה: TPU_CORE_TYPE_TENSOR_CORE.
index מספר שלם האינדקס של הליבה בשבב.

SequencerInfo

המילון SequencerInfo מכיל מידע על המצב של רכיב sequencer יחיד בליבה.

שדה סוג תיאור
sequencer_type מחרוזת סוג הרצף. לדוגמה: TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER.
sequencer_index מספר שלם האינדקס של הרכיב ליצירת רצף (אם יש כמה רכיבים מאותו סוג).
pc מספר שלם הערך הנוכחי של מונה התוכנית.
program_id מספר שלם המזהה שמשויך למופע ספציפי של תוכנית שמופעלת לביצוע בליבת TPU.
run_id מספר שלם מזהה הריצה שמשויך למופע ספציפי של הפעלת תוכנית בליבת TPU.
hlo_location מחרוזת נתוני מיקום של כלי האופטימיזציה ברמה גבוהה.
hlo_detailed_info מחרוזת מידע מפורט על הכלי לשיפור הביצועים ברמה גבוהה.

QueuedProgramInfo

המילון QueuedProgramInfo מכיל מידע על תוכניות שנמצאות בתור להרצה בליבה.

שדה סוג תיאור
run_id מספר שלם מזהה ההפעלה של התוכנית שהוכנסה לתור.
launch_id מספר שלם מזהה ההשקה של התוכנית שנוספה לתור.
program_fingerprint בייטים טביעת האצבע של התוכנית בתור.

‫TPU-Z עם JAX

אפשר לגשת למדדים של TPU-Z בעומסי עבודה של JAX באמצעות הספרייה libtpu.sdk. סקריפט Python הבא משתמש ב-JAX לחישוב טנסורים עם ביצועים גבוהים, ובמקביל משתמש ב-libtpu SDK בשרשור ברקע כדי לעקוב אחרי המצב והפעילות של חומרת ה-TPU הבסיסית.

צריך לכלול את חבילות Python הבאות:

import jax
import jax.numpy as jnp
import time
import threading
from functools import partial
from libtpu import sdk

הפונקציה monitor_tpu_status משתמשת בשרשור ברקע כדי להציג באופן רציף את סטטוס הפעולה של ליבות ה-TPU, בזמן שהאפליקציה הראשית מבצעת עומס עבודה של JAX. הוא משמש ככלי לאבחון בעיות בזמן אמת.

def monitor_tpu_status():
  """Monitors TPU status in a background thread."""

  while monitoring_active:
    try:
      summary = sdk.tpuz.get_core_state_summary(include_hlo_info=True)
      if summary and 'core_states' in summary:
        print(summary)
      else:
        print('WARNING: Call returned an empty or invalid summary.')
    except RuntimeError as e:
      print(f'FAIL: Error calling API: {e}')
    except Exception as e:
      print(f'FAIL: Unexpected error in monitor thread: {e}')

    for _ in range(MONITORING_INTERVAL_SECONDS * 2):
      if not monitoring_active:
        break
      time.sleep(0.5)
  print('✅ Monitoring thread stopped.')

הפונקציה transformer_block מיישמת שכבה מלאה של ארכיטקטורת ה-Transformer, שהיא אבן הבניין הבסיסית של מודלים של LLM.

@partial(jax.jit, static_argnames=['num_heads'])
def transformer_block(params, x, num_heads=32):
  """A simplified but computationally intensive Transformer block."""
  # Multi-head Self-Attention
  qkv = jnp.dot(x, params['qkv_kernel'])
  q, k, v = jnp.array_split(qkv, 3, axis=-1)

  # Reshape for multi-head attention
  q = q.reshape(q.shape[0], q.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
  k = k.reshape(k.shape[0], k.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
  v = v.reshape(v.shape[0], v.shape[1], num_heads, -1).transpose(0, 2, 1, 3)

  # Scaled dot-product attention
  attention_scores = jnp.einsum('nhqd,nhkd->nhqk', q, k) / jnp.sqrt(q.shape[-1])
  attention_weights = jax.nn.softmax(attention_scores, axis=-1)
  attention_output = jnp.einsum('nhqk,nhvd->nhqd', attention_weights, v)
  attention_output = attention_output.transpose(0, 2, 1, 3).reshape(x.shape)

  attention_output = jnp.dot(attention_output, params['o_kernel'])

  # Residual connection and Layer Normalization 1
  h1 = x + attention_output
  h1_norm = h1 - jnp.mean(h1, axis=-1, keepdims=True)
  h1_norm = h1_norm / jnp.sqrt(
      jnp.mean(jnp.square(h1_norm), axis=-1, keepdims=True) + 1e-5
  )

  # Feed-Forward Network
  ffn_hidden = jax.nn.gelu(jnp.dot(h1_norm, params['ffn1_kernel']))
  ffn_output = jnp.dot(ffn_hidden, params['ffn2_kernel'])

  # Residual connection and Layer Normalization 2
  h2 = h1_norm + ffn_output
  h2_norm = h2 - jnp.mean(h2, axis=-1, keepdims=True)
  h2_norm = h2_norm / jnp.sqrt(
      jnp.mean(jnp.square(h2_norm), axis=-1, keepdims=True) + 1e-5
  )

  return h2_norm

הפונקציה main מתזמנת את ההגדרה של החישוב ב-JAX, מפעילה את המעקב ברקע של TPU ומריצה את הלולאה הראשית של עומס העבודה.

def main():
  num_devices = jax.device_count()
  print(f"Running on {num_devices} devices.")

  batch_size = 128 * num_devices
  seq_len = 512
  embed_dim = 1024
  ffn_dim = embed_dim * 4

  key = jax.random.PRNGKey(0)

  params = {
      'qkv_kernel': jax.random.normal(
          key, (embed_dim, embed_dim * 3), dtype=jnp.bfloat16
      ),
      'o_kernel': jax.random.normal(
          key, (embed_dim, embed_dim), dtype=jnp.bfloat16
      ),
      'ffn1_kernel': jax.random.normal(
          key, (embed_dim, ffn_dim), dtype=jnp.bfloat16
      ),
      'ffn2_kernel': jax.random.normal(
          key, (ffn_dim, embed_dim), dtype=jnp.bfloat16
      ),
  }
  input_data = jax.random.normal(
      key, (batch_size, seq_len, embed_dim), dtype=jnp.bfloat16
  )
  input_data = jax.device_put(input_data)
  monitor_thread = threading.Thread(target=monitor_tpu_status)
  monitor_thread.start()
  print("Starting JAX computation loop...")
  start_time = time.time()
  iterations = 0
  while time.time() - start_time < JOB_DURATION_SECONDS:
    result = transformer_block(params, input_data)
    result.block_until_ready()
    iterations += 1
    print(f'  -> Jax iteration {iterations} complete.', end='\r')

  print(f"\nCompleted {iterations} iterations in {JOB_DURATION_SECONDS} seconds.")

  global monitoring_active
  monitoring_active = False
  monitor_thread.join()

if __name__ == '__main__':
  main()

פתרון בעיות

בקטע הזה מופיע מידע לפתרון בעיות שיעזור לכם לזהות ולפתור בעיות שאתם עשויים להיתקל בהן במהלך השימוש בספריית המעקב של TPU.

תכונות או מדדים חסרים

אם אתם לא מצליחים לראות חלק מהתכונות או מהמדדים, הסיבה הכי נפוצה לכך היא גרסה לא עדכנית של libtpu. התכונות והמדדים של ספריית המעקב אחרי TPU כלולים בגרסאות libtpu, ובגרסאות ישנות יותר יכול להיות שחסרים מדדים ותכונות חדשים.

בודקים את הגרסה של libtpu שפועלת בסביבה שלכם:

שורת פקודה:

pip show libtpu

‫Python:

import libtpu

print(libtpu.__version__)

אם אתם לא משתמשים בגרסה האחרונה של libtpu, אתם יכולים להשתמש בפקודה הבאה כדי לעדכן את הספרייה:

pip install --upgrade libtpu