שמירה של התקדמות האימון באמצעות Autocheckpoint
בעבר, כשמכונה וירטואלית של TPU דרשה תחזוקה, התהליך התחיל באופן מיידי, בלי להשאיר זמן למשתמשים לבצע פעולות לשמירת ההתקדמות, כמו שמירת נקודת ביקורת. אפשר לראות את זה באיור 1(א).

איור 1. איור של התכונה 'נקודת ביקורת אוטומטית': (א) בלי נקודת ביקורת אוטומטית, התקדמות האימון מנקודת הביקורת האחרונה נעלמת כשמתרחש אירוע תחזוקה. (ב) באמצעות Autocheckpoint, אפשר לשמור את התקדמות האימון מאז נקודת הבדיקה האחרונה, אם צפוי אירוע תחזוקה.
אפשר להשתמש ב-Autocheckpoint (איור 1(ב)) כדי לשמור את ההתקדמות בתהליך האימון. לשם כך, צריך להגדיר את הקוד כך שישמור נקודת ביקורת לא מתוזמנת כשמתרחש אירוע תחזוקה. כשמתרחש אירוע תחזוקה, ההתקדמות מאז נקודת הבדיקה האחרונה נשמרת באופן אוטומטי. התכונה פועלת גם בפרוסות בודדות וגם בפרוסות מרובות.
התכונה Autocheckpoint פועלת עם מסגרות שיכולות לתעד אותות SIGTERM ולשמור נקודת ביקורת. בין ה-frameworks הנתמכים:
שימוש ב-Autocheckpoint
התכונה 'נקודת עצירה אוטומטית' מושבתת כברירת מחדל. כשיוצרים TPU או שולחים בקשה למשאב בתור, אפשר להפעיל את השמירה האוטומטית של נקודות ביקורת על ידי הוספת הדגל --autocheckpoint-enabled כשמקצים את ה-TPU.
כשהתכונה מופעלת, Cloud TPU מבצע את השלבים הבאים ברגע שהוא מקבל הודעה על אירוע תחזוקה:
- לכידת אות SIGTERM שנשלח לתהליך באמצעות מכשיר TPU
- מחכים עד שהתהליך יסתיים או עד שיחלפו 5 דקות, מה שקורה קודם.
- ביצוע תחזוקה בפרוסות שהושפעו
התשתית שבה נעשה שימוש ב-Autocheckpoint לא תלויה במסגרת ML. כל מסגרת למידת מכונה יכולה לתמוך ב-Autocheckpoint אם היא יכולה ללכוד את אות ה-SIGTERM ולהתחיל תהליך של יצירת נקודת ביקורת.
בקוד האפליקציה, צריך להפעיל את היכולות של Autocheckpoint שמסופקות על ידי מסגרת ה-ML. לדוגמה, ב-Pax, צריך להפעיל את האפשרויות בשורת הפקודה כשמפעילים את ההדרכה. מידע נוסף זמין במאמר המדריך למתחילים לשימוש ב-Autocheckpoint עם Pax. מאחורי הקלעים, המסגרות שומרות נקודת ביקורת לא מתוזמנת כשמתקבל אות SIGTERM, ומכונת ה-TPU הווירטואלית המושפעת עוברת תחזוקה כשה-TPU כבר לא בשימוש.
מדריך למתחילים: יצירת נקודות ביקורת אוטומטיות באמצעות MaxText
MaxText היא ספרייה של מודלי שפה גדולים (LLM) בקוד פתוח, עם ביצועים גבוהים, יכולת התאמה לעומס (autoscaling) שרירותית ויישום הפניה שנבדק היטב, שנכתב ב-Python/JAX טהור ומיועד ל-Cloud TPU. MaxText מכיל את כל ההגדרות הנדרשות לשימוש בתכונה Autocheckpoint.
בקובץ MaxText READMEמתוארות שתי דרכים להפעלת MaxText בהיקף גדול:
- שימוש ב-
multihost_runner.py, מומלץ לניסויים - שימוש ב-
multihost_job.py, מומלץ לשימוש בסביבת ייצור
כשמשתמשים ב-multihost_runner.py, צריך להפעיל את התכונה 'שמירת נקודת ביקורת אוטומטית' על ידי הגדרת הדגל autocheckpoint-enabled כשמקצים את המשאב בתור.
כשמשתמשים ב-multihost_job.py, מפעילים את התכונה 'שמירת נקודת ביקורת אוטומטית' על ידי ציון הדגל ENABLE_AUTOCHECKPOINT=true בשורת הפקודה כשמפעילים את העבודה.
מדריך למתחילים: יצירת נקודת ביקורת אוטומטית באמצעות Pax בפרוסה אחת
בקטע הזה מובאת דוגמה להגדרה ולשימוש ב-Autocheckpoint עם Pax בפלח יחיד. עם ההגדרה המתאימה:
- נקודת ביקורת תישמר כשמתרחש אירוע תחזוקה.
- Cloud TPU יבצע תחזוקה במכונות ה-TPU הווירטואליות המושפעות אחרי שמירת נקודת הבדיקה.
- אחרי ש-Cloud TPU משלים את התחזוקה, אפשר להשתמש במכונת ה-TPU הווירטואלית כרגיל.
משתמשים בדגל
autocheckpoint-enabledכשיוצרים את המכונה הווירטואלית של TPU או כשמבקשים משאב בתור.לדוגמה:
מגדירים משתני סביבה:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=zone-you-want-to-use export ACCELERATOR_TYPE=your-accelerator-type export RUNTIME_VERSION=tpu-ubuntu2204-base
תיאורים של משתני סביבה
PROJECT_ID: מזהה הפרויקט ב- Google Cloud . משתמשים בפרויקט קיים או יוצרים פרויקט חדש.-
TPU_NAME: השם של ה-TPU. -
ZONE: האזור שבו תיצור את מכונת ה-TPU הווירטואלית. מידע נוסף על אזורים נתמכים זמין במאמר אזורים ותחומים של TPU. -
ACCELERATOR_TYPE: סוג המאיץ מציין את הגרסה והגודל של Cloud TPU שרוצים ליצור. מידע נוסף על סוגי המאיצים הנתמכים בכל גרסת TPU זמין במאמר בנושא גרסאות TPU. -
RUNTIME_VERSION: גרסת התוכנה של Cloud TPU.
מגדירים את מזהה הפרויקט ואת האזור בהגדרות הפעילות:
gcloud config set project $PROJECT_ID gcloud config set compute/zone $ZONE
יוצרים TPU:
gcloud alpha compute tpus tpu-vm create $TPU_NAME \ --accelerator-type $ACCELERATOR_TYPE \ --version $RUNTIME_VERSION \ --autocheckpoint-enabled
מתחברים ל-TPU באמצעות SSH:
gcloud compute tpus tpu-vm ssh $TPU_NAMEהתקנת Pax על פרוסה אחת
התכונה Autocheckpoint פועלת בגרסאות Pax 1.1.0 ואילך. ב-TPU VM, מתקינים את
jax[tpu]ואת הגרסה האחרונה שלpaxml:pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
מגדירים את מודל
LmCloudSpmd2B. לפני שמריצים את סקריפט ההדרכה, משנים אתICI_MESH_SHAPEל-[1, 8, 1]:@experiment_registry.register class LmCloudSpmd2B(LmCloudSpmd): """SPMD model with 2B params. Global batch size = 2 * 2 * 1 * 32 = 128 """ PERCORE_BATCH_SIZE = 8 NUM_LAYERS = 18 MODEL_DIMS = 3072 HIDDEN_DIMS = MODEL_DIMS * 4 CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING ICI_MESH_SHAPE = [1, 8, 1]
מפעילים את האימון עם ההגדרה המתאימה.
בדוגמה הבאה מוצגות ההגדרות של מודל
LmCloudSpmd2Bלשמירת נקודות ביקורת שהופעלו על ידי Autocheckpoint בקטגוריה של Cloud Storage. מחליפים את your-storage-bucket בשם של קטגוריה קיימת, או יוצרים קטגוריה חדשה.export JOB_LOG_DIR=gs://your-storage-bucket { python3 .local/lib/python3.10/site-packages/paxml/main.py \ --jax_fully_async_checkpoint=1 \ --exit_after_ondemand_checkpoint=1 \ --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \ --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt
שימו לב לשני הדגלים שמועברים לפקודה:
-
jax_fully_async_checkpoint: אם הדגל הזה מופעל, נעשה שימוש ב-orbax.checkpoint.AsyncCheckpointer. הכיתהAsyncCheckpointerשומרת באופן אוטומטי נקודת ביקורת כשסקריפט האימון מקבל אות SIGTERM. -
exit_after_ondemand_checkpoint: אם הדגל הזה מופעל, תהליך ה-TPU יוצא אחרי שנקודת הבדיקה האוטומטית נשמרת בהצלחה, מה שגורם לביצוע התחזוקה באופן מיידי. אם לא משתמשים בדגל הזה, האימון יימשך אחרי שמירת נקודת הבדיקה, ו-Cloud TPU יחכה להתרחשות זמן קצוב לתפוגה (5 דקות) לפני ביצוע התחזוקה הנדרשת.
-
נקודת ביקורת אוטומטית באמצעות Orbax
התכונה 'בדיקה אוטומטית' לא מוגבלת ל-MaxText או ל-Pax. כל מסגרת שיכולה לתעד את אות ה-SIGTERM ולהתחיל תהליך של יצירת נקודת ביקורת פועלת עם התשתית שמסופקת על ידי Autocheckpoint. Orbax, מרחב שמות שמספק ספריות שירות נפוצות למשתמשי JAX, מספק את היכולות האלה.
כמו שמוסבר במסמכי Orbax, היכולות האלה מופעלות כברירת מחדל למשתמשי orbax.checkpoint.CheckpointManager. השיטה save שמופעלת אחרי כל שלב בודקת באופן אוטומטי אם צפוי אירוע תחזוקה. אם כן, היא שומרת נקודת ביקורת גם אם מספר השלב הוא לא כפולה של save_interval_steps.
במאמרי העזרה של GitHub מוסבר גם איך לשנות את קוד המשתמש כדי שהאימון יסתיים אחרי שמירת נקודת ביקורת אוטומטית.