העברת עומסי עבודה (workloads) של JAX ל-Pathways

בגלל האופי המבוזר של JAX עם Pathways, יכול להיות שחלק מהפעולות לא יתבצעו בצורה טובה בגלל תקורה של תקשורת. ב-Pathways יש תכונות כמו שליחה אסינכרונית שמצמצמות את התקורה הזו, אבל יש כמה דברים שחשוב לדעת כשמעבירים עומסי עבודה של JAX ל-Pathways או כשמגדילים את עומס העבודה של JAX עם Pathways למספר גדול של מאיצים.

לפני שמתחילים

חשוב לוודא שיש לכם:

אינדקס תהליכים

ב-JAX עם Pathways, כל המכשירים באשכול Pathways נחשבים למקומיים. כך ניהול המכשירים פשוט יותר, ומאפשר ל-JAX להשתמש בכל המשאבים הזמינים. בפועל, המשמעות היא:

  • הערך של jax.process_index() הוא תמיד 0 בכל המכשירים.
  • jax.devices() ו-jax.local_devices() מחזירים את כל מכשירי ה-TPU בכל המשימה.

סוג החומרה ומיקום השרתים

כדי להשיג את הביצועים הכי טובים, צריך למקם את כל רכיבי Pathways ואת עבודת המשתמש באותו אזור ענן של Google Cloud . משתמשים במעבד גדול כמו IFRT proxy ו-resource manager. מומלץ להשתמש לפחות ב-n2-standard-64 ייעודי, עם 64 ליבות וירטואליות וזיכרון של 256GB.

PathwaysUtils

Pathways-utils הוא מאגר GitHub מבוסס-Python שמספק כלי עזר וכלים חיוניים שמאפשרים לייעל את הפריסה וההפעלה של עומסי עבודה של JAX בארכיטקטורה של Pathways on Cloud. החבילה הזו מטפלת בהתאמות הנדרשות לסביבת הענן, ומאפשרת למפתחי JAX להתמקד בתהליכי העבודה העיקריים של למידת מכונה עם מינימום הגדרות ספציפיות לפלטפורמה. באופן ספציפי, הוא מציע:

  • עורף קצה (backend) של JAX מסוג proxy: עורף הקצה המותאם אישית הזה מאפשר לאפליקציית JAX להשתמש בתשתית של Pathways על ידי הגדרת משתנה הסביבה JAX_PLATFORMS=proxy.
  • כלי פרופיל משולבים: יכולות פרופיל שמאפשרות להבין את הביצועים של האפליקציה. באמצעות ממשקי API סטנדרטיים של JAX ליצירת פרופילים, כמו jax.profiler.start_trace ו-jax.profiler.start_server, אתם יכולים ליצור פרופילים לא רק של קוד JAX, אלא גם של רכיבי Pathways הבסיסיים, וכך לקבל תמונה הוליסטית של הביצוע בסביבת הענן.
  • ‫Distributed Checkpointing with Orbax: ‏ handler מותאם אישית של Orbax checkpoint שמאפשר לכם להשתמש ב-distributed checkpoints ולשחזר את ה-checkpoints שלכם כשאתם משתמשים בספריית Orbax בסביבת Pathways. השילוב הזה נועד לפעול בלי לדרוש שינויים בקוד הקיים של Orbax checkpointing, כל עוד הוא מייבא את pathwaysutils.
  • Elastic Training Primitives: מספק פרימיטיבים בסיסיים של אימון גמיש שאפשר להשתמש בהם כדי ליצור תהליכי עבודה חזקים וניתנים להרחבה של אימון באמצעות Pathways. הפרימיטיבים האלה מאפשרים לעבודות האימון שלכם להסתגל באופן דינמי לשינויים במשאבים הזמינים, וכך לשפר את היעילות והעמידות בסביבות ענן.

Checkpointing

Orbax נבדק באופן יסודי עם Pathways לצורך יצירת נקודות ביקורת מבוזרות ושחזור באמצעות Cloud Storage. כשמגדירים את משתנה הסביבה ENABLE_PATHWAYS_PERSISTENCE=1 וקוראים ל-import pathwaysutils; pathwaysutils.initialize() ב-train.py, נרשם ArrayHandler מותאם אישית שמטפל ביעילות בפעולות של נקודות ביקורת דרך ה-proxy של IFRT, וכך מאפשר לעובדי Pathways במאיצים לשמור ולשחזר נתונים ישירות.

Python שמוצב באותו מיקום

Colocated Python הוא API של JAX בקוד פתוח שמאפשר להריץ קוד Python שצוין על ידי המשתמש ישירות במארחי TPU או GPU, וזה פשוט יותר ב-JAX עם כמה בקרי TPU. כך אפשר לבצע משימות שדורשות יותר כוח מחשוב, כמו טעינת נתונים ויצירת נקודות ביקורת, בלי להעביר נתונים בין הלקוח למכונות TPU. כדי להגדיר את אשכול Pathways להרצת Python JAX API במיקום משותף, פועלים לפי ההוראות בקובץ ה-README של Python במיקום משותף. בהוראות האלה מוסבר איך להפעיל תהליך Python sidecar שמוצב באותו מיקום לצד תהליכי העבודה של Pathways.

טעינת נתונים

במהלך האימון אנחנו טוענים שוב ושוב אצוות ממערך נתונים כדי להזין אותן למודל. כדי למנוע מצב שבו מאיצים לא מקבלים מספיק עבודה, חשוב להשתמש בטוען נתונים אסינכרוני יעיל שמחלק את האצווה בין המארחים. כשמריצים אימון באמצעות Pathways, טוען הנתונים פועל במכונה וירטואלית של CPU (בניגוד למכונה וירטואלית של TPU שמשמשת בהגדרות מרובות בקרים) ושולח נתונים למכונות וירטואליות של TPU. הפעולה הזו גורמת לזמן אחזור ארוך יותר בקריאת הנתונים, אבל אפשר לצמצם את ההשפעה שלה על ידי קריאה מראש של X מספר אצוות במארח המעבד (CPU) ושליחת הנתונים שנקראו באופן אסינכרוני ל-TPU. הפתרון הזה מספיק כשמריצים אותו בהיקף קטן עד בינוני.

כדי להשיג ביצועים אופטימליים בהיקף גדול, מומלץ מאוד למקם את צינור הנתונים של הקלט במיקום משותף באמצעות Python במיקום משותף כדי להפעיל את צינור הנתונים ישירות במאיצים. כך נמנע צוואר בקבוק במעבד, ונעשה שימוש בחיבורים המהירים של TPU להעברת נתונים.

אפשר למצוא הטמעה לדוגמה של העברת צינור קלט שמבוסס על TFDS בהטמעה של RemoteIterator ב-multihost_dataloading.py. ההטמעה הזו פועלת גם ב-JAX עם כמה בקרי Pathways וגם ב-Pathways באופן מבוזר באמצעות Python JAX API שמוצב באותו מיקום.

ניהול גרסאות של Jax

הגרסאות של Pathways קשורות באופן הדוק לגרסאות של JAX כדי להבטיח תאימות ויציבות. כדי למנוע בעיות אפשריות, חשוב לוודא שפריטי ה-artifact של Pathways וגרסת JAX תואמים. בכל גרסה של Pathways מצוינות בבירור גרסאות JAX התואמות באמצעות תג מהצורה jax-<version>.

מטמון של הידור

מטמון קומפילציה מתמשך של Pathways הוא תכונה שמאפשרת לשרתים של Pathways לאחסן קובצי הפעלה של XLA שעברו קומפילציה במיקום מתמשך, כמו Cloud Storage, כדי למנוע קומפילציה מיותרת. התכונה הזו מופעלת כברירת מחדל. המיקום של הזיכרון מטמון מועבר כדגל --gcs_scratch_location למנהל המשאבים ולמאגרי העובדים של Pathways. כדי לצמצם את עלויות האחסון המשויכות, המטמון מצרף מדיניות מחזור חיים למיקום ב-Cloud Storage. יש מגבלה של 50 כללי מדיניות לכל קטגוריה של Cloud Storage. לכן, מומלץ להשתמש במיקום משותף ב-Cloud Storage בכל עומסי העבודה.

המטמון הזה דומה למטמון הקומפילציה של JAX, שמושבת על ידי pathwaysutils.initialize() עבור עומסי עבודה של Pathways.

נדרשות ההרשאות הבאות ב-Cloud Storage כדי להשתמש במטמון של הקומפילציה:

  • storage.buckets.get: כדי לאחזר מטא-נתונים של קטגוריה.
  • storage.buckets.update: חיוני ל-Pathways כדי להגדיר מדיניות מחזור חיים של אובייקטים לאכיפת TTL לצורך פינוי מטמון.
  • storage.objects.list: כדי להציג רשימה של אובייקטים קיימים במטמון בתוך הקטגוריה.
  • storage.objects.create: כדי לכתוב קובצי הפעלה חדשים שעברו קומפילציה למטמון.
  • storage.objects.get: כדי לקרוא קובצי הפעלה ששמורים במטמון מהקטגוריה.

יצירת פרופילים

אפשר להשתמש בכלי ליצירת פרופילים של JAX כדי ליצור עקבות של תוכנית JAX. יש שתי דרכים נפוצות שנתמכות ב-Pathways:

  • פרוגרמטי
    • איך לוכדים פרופילים באופן פרוגרמטי מקוד JAX
  • באופן ידני
    • תיעוד פרופילים לפי דרישה אחרי הפעלת שרת הפרופילים מקוד JAX

בשני המקרים, הפרופילים נכתבים לקטגוריה של Cloud Storage. ייווצרו כמה קובצי מעקב בקטגוריה של Cloud Storage, יכול להיות שמתחת לתיקיות שונות של חותמות זמן, למשל:

  • תהליך Python הראשי שהפעיל את המעקב (בדרך כלל מכונת ה-VM של ה-Notebook): <jax-client-vm-name>.xplane.pb
  • שרת proxy של Pathways IFRT: client.<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb
  • מנהל המשאבים של תוכניות הלימודים: server.*<var>PATHWAYS_HEAD_NODE_NAME</var>.xplane.pb
  • קובצי שירות של Pathways: server.*<tpu-node-name>.xplane.pb

אפשר לנתח את קובצי העקבות האלה באמצעות TensorBoard על ידי הרצת הפקודה הבאה. מידע נוסף על TensorBoard ועל כל כלי הפרופילים שלה זמין במאמר אופטימיזציה של הביצועים של TensorFlow באמצעות כלי הפרופילים.

# verify trace files are present
gsutil ls -l -r gs://BUCKET/PREFIX

# View on tensorboard
tensorboard --logdir=gs://BUCKET/PREFIX

מחליפים את מה שכתוב בשדות הבאים:

  • BUCKET : קטגוריה של Cloud Storage לאחסון קובצי המעקב
  • PREFIX: נתיב בתוך הקטגוריה של Cloud Storage לאחסון קובצי ה-Trace

לכידת פרופילים פרוגרמטית

מצלמים פרופיל מתוך הקוד. הפרופילים נשמרים בתיקייה gs://<bucket>/<prefix> עם חותמת זמן

import jax
import pathwaysutils

pathwaysutils.initialize()

jax.profiler.start_trace("gs://BUCKET/PREFIX")

# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()

jax.profiler.stop_trace()

לכידת פרופיל באופן ידני

כדי ללכוד את פרטי הפרופיל באופן ידני, צריך להפעיל את שרת הפרופילר מקוד Python:

import jax
import pathwaysutils

pathwaysutils.initialize()
jax.profiler.start_server(jax_profiler_port)

# Your JAX code
jax.profiler.stop_server() # this is functionally a no-op

בזמן ששרת הפרופיל פועל, אפשר לצלם פרופיל ולייצא את הנתונים למיקום היעד ב-Cloud Storage:

export DURATION_IN_SECS=6
export OUTPUT_DIR=gs://<var>BUCKET</var>/<var>PREFIX</var>

curl -d "{\"duration_ms\":\"${DURATION_IN_SECS} * 1000 }}\", \"repository_path\":\"${OUTPUT_DIR}\"}" -H "Content-Type: application/json" -X POST http://localhost:<jax_profiler_port>

אפשר למצוא מידע על תזמון של שיטות לקוח פרוקסי של IFRT כמו Compile ו-Execute במעקב של התוכנית. האירועים האלה, שמפרטים את האינטראקציות עם שרת ה-proxy של IFRT gRPC במהלך ההידור והביצוע, מופיעים בשרשור שנקרא GrpcClientSessionUserFuturesWorkQueue. בדיקת השרשור הזה בנתוני המעקב יכולה לספק תובנות לגבי הביצועים של הפעולות האלה.

סימוני XLA

כשמשתמשים ב-Pathways, צריך להגדיר את דגלי ה-XLA במאגר pathways-proxy. אפשר לעשות את זה באמצעות XPK או PathwaysJob API.

כשמשתמשים ב-XPK, מגדירים דגלי XLA כמו בדוגמה הבאה:

--custom-pathways-proxy-server-args="--xla_flag_1=value1 --xla_flag_2=value2"

כשמשתמשים ב-PathwaysJob API, מגדירים דגלי XLA כמו בדוגמה הבאה:

apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USER
spec:
  customComponents:
  - componentType: proxy_server
    customFlags:
    - --xla_flag_1=value1
    - --xla_flag_2=value2

מחליפים את מה שכתוב בשדות הבאים:

  • USER : שם המשתמש שלך Google Cloud
  • value[n]: הדגלים של XLA שרוצים להגדיר

HLO Dump

כדי לבצע ניתוח מעמיק של קלט של High Level Optimizer ‏ (HLO) שמועבר לקומפיילר XLA, אפשר להגדיר את Pathways כך שיבצע dump של ה-HLO למיקום ספציפי ב-Cloud Storage באופן הבא:

apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USER
spec:
  customComponents:
  - componentType: proxy_server
    customEnv:
    - name: XLA_FLAGS
      value: "--xla_dump_to=gs://your-bucket/your-desired-prefix/"

המאמרים הבאים