שמירה של התקדמות האימון באמצעות Autocheckpoint

בעבר, כשמכונה וירטואלית של TPU דרשה תחזוקה, התהליך התחיל באופן מיידי, בלי להשאיר זמן למשתמשים לבצע פעולות לשמירת ההתקדמות, כמו שמירת נקודת ביקורת. אפשר לראות את זה באיור 1(א).

תרשים שמראה את ההשפעה של תחזוקת המארח עם ובלי שמירת נקודות ביקורת אוטומטית

איור 1. איור של התכונה 'נקודת ביקורת אוטומטית': (א) בלי נקודת ביקורת אוטומטית, התקדמות האימון מנקודת הביקורת האחרונה נעלמת כשמתרחש אירוע תחזוקה. ‫(ב) באמצעות Autocheckpoint, אפשר לשמור את התקדמות האימון מאז נקודת הבדיקה האחרונה, אם צפוי אירוע תחזוקה.

אפשר להשתמש ב-Autocheckpoint (איור 1(ב)) כדי לשמור את ההתקדמות בתהליך האימון. לשם כך, צריך להגדיר את הקוד כך שישמור נקודת ביקורת לא מתוזמנת כשמתרחש אירוע תחזוקה. כשמתרחש אירוע תחזוקה, ההתקדמות מאז נקודת הבדיקה האחרונה נשמרת באופן אוטומטי. התכונה פועלת גם בפרוסות בודדות וגם בפרוסות מרובות.

התכונה Autocheckpoint פועלת עם מסגרות שיכולות לתעד אותות SIGTERM ולשמור נקודת ביקורת. בין ה-frameworks הנתמכים:

שימוש ב-Autocheckpoint

התכונה 'נקודת עצירה אוטומטית' מושבתת כברירת מחדל. כשיוצרים TPU או שולחים בקשה למשאב בתור, אפשר להפעיל את השמירה האוטומטית של נקודות ביקורת על ידי הוספת הדגל --autocheckpoint-enabled כשמקצים את ה-TPU. כשהתכונה מופעלת, Cloud TPU מבצע את השלבים הבאים ברגע שהוא מקבל הודעה על אירוע תחזוקה:

  1. לכידת אות SIGTERM שנשלח לתהליך באמצעות מכשיר TPU
  2. מחכים עד שהתהליך יסתיים או עד שיחלפו 5 דקות, מה שקורה קודם.
  3. ביצוע תחזוקה בפרוסות שהושפעו

התשתית שבה נעשה שימוש ב-Autocheckpoint לא תלויה במסגרת ML. כל מסגרת למידת מכונה יכולה לתמוך ב-Autocheckpoint אם היא יכולה ללכוד את אות ה-SIGTERM ולהתחיל תהליך של יצירת נקודת ביקורת.

בקוד האפליקציה, צריך להפעיל את היכולות של Autocheckpoint שמסופקות על ידי מסגרת ה-ML. לדוגמה, ב-Pax, צריך להפעיל את האפשרויות בשורת הפקודה כשמפעילים את ההדרכה. מידע נוסף זמין במאמר המדריך למתחילים לשימוש ב-Autocheckpoint עם Pax. מאחורי הקלעים, המסגרות שומרות נקודת ביקורת לא מתוזמנת כשמתקבל אות SIGTERM, ומכונת ה-TPU הווירטואלית המושפעת עוברת תחזוקה כשה-TPU כבר לא בשימוש.

מדריך למתחילים: יצירת נקודות ביקורת אוטומטיות באמצעות MaxText

MaxText היא ספרייה של מודלי שפה גדולים (LLM) בקוד פתוח, עם ביצועים גבוהים, יכולת התאמה לעומס (autoscaling) שרירותית ויישום הפניה שנבדק היטב, שנכתב ב-Python/JAX טהור ומיועד ל-Cloud TPU. ‫MaxText מכיל את כל ההגדרות הנדרשות לשימוש בתכונה Autocheckpoint.

בקובץ MaxText READMEמתוארות שתי דרכים להפעלת MaxText בהיקף גדול:

כשמשתמשים ב-multihost_runner.py, צריך להפעיל את התכונה 'שמירת נקודת ביקורת אוטומטית' על ידי הגדרת הדגל autocheckpoint-enabled כשמקצים את המשאב בתור.

כשמשתמשים ב-multihost_job.py, מפעילים את התכונה 'שמירת נקודת ביקורת אוטומטית' על ידי ציון הדגל ENABLE_AUTOCHECKPOINT=true בשורת הפקודה כשמפעילים את העבודה.

מדריך למתחילים: יצירת נקודת ביקורת אוטומטית באמצעות Pax בפרוסה אחת

בקטע הזה מובאת דוגמה להגדרה ולשימוש ב-Autocheckpoint עם Pax בפלח יחיד. עם ההגדרה המתאימה:

  • נקודת ביקורת תישמר כשמתרחש אירוע תחזוקה.
  • ‫Cloud TPU יבצע תחזוקה במכונות ה-TPU הווירטואליות המושפעות אחרי שמירת נקודת הבדיקה.
  • אחרי ש-Cloud TPU משלים את התחזוקה, אפשר להשתמש במכונת ה-TPU הווירטואלית כרגיל.
  1. משתמשים בדגל autocheckpoint-enabled כשיוצרים את המכונה הווירטואלית של TPU או כשמבקשים משאב בתור.

    לדוגמה:

    1. מגדירים משתני סביבה:

      export PROJECT_ID=your-project-id
      export TPU_NAME=your-tpu-name
      export ZONE=zone-you-want-to-use
      export ACCELERATOR_TYPE=your-accelerator-type
      export RUNTIME_VERSION=tpu-ubuntu2204-base

      תיאורים של משתני סביבה

      • PROJECT_ID: מזהה הפרויקט ב- Google Cloud . משתמשים בפרויקט קיים או יוצרים פרויקט חדש.
      • TPU_NAME: השם של ה-TPU.
      • ZONE: האזור שבו תיצור את מכונת ה-TPU הווירטואלית. מידע נוסף על אזורים נתמכים זמין במאמר אזורים ותחומים של TPU.
      • ACCELERATOR_TYPE: סוג המאיץ מציין את הגרסה והגודל של Cloud TPU שרוצים ליצור. מידע נוסף על סוגי המאיצים הנתמכים בכל גרסת TPU זמין במאמר בנושא גרסאות TPU.
      • RUNTIME_VERSION: גרסת התוכנה של Cloud TPU.

    2. מגדירים את מזהה הפרויקט ואת האזור בהגדרות הפעילות:

      gcloud config set project $PROJECT_ID
      gcloud config set compute/zone $ZONE
    3. יוצרים TPU:

      gcloud alpha compute tpus tpu-vm create $TPU_NAME \
          --accelerator-type $ACCELERATOR_TYPE \
          --version $RUNTIME_VERSION \
          --autocheckpoint-enabled
  2. מתחברים ל-TPU באמצעות SSH:

    gcloud compute tpus tpu-vm ssh $TPU_NAME
    
  3. התקנת Pax על פרוסה אחת

    התכונה Autocheckpoint פועלת בגרסאות Pax 1.1.0 ואילך. ב-TPU VM, מתקינים את jax[tpu] ואת הגרסה האחרונה של paxml:

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  4. מגדירים את מודל LmCloudSpmd2B. לפני שמריצים את סקריפט ההדרכה, משנים את ICI_MESH_SHAPE ל-[1, 8, 1]:

    @experiment_registry.register
    class LmCloudSpmd2B(LmCloudSpmd):
    
        """SPMD model with 2B params.
    
        Global batch size = 2 * 2 * 1 * 32 = 128
        """
        PERCORE_BATCH_SIZE = 8
    
        NUM_LAYERS = 18
        MODEL_DIMS = 3072
        HIDDEN_DIMS = MODEL_DIMS * 4
    
        CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING
        ICI_MESH_SHAPE = [1, 8, 1]
  5. מפעילים את האימון עם ההגדרה המתאימה.

    בדוגמה הבאה מוצגות ההגדרות של מודל LmCloudSpmd2B לשמירת נקודות ביקורת שהופעלו על ידי Autocheckpoint בקטגוריה של Cloud Storage. מחליפים את your-storage-bucket בשם של קטגוריה קיימת, או יוצרים קטגוריה חדשה.

    export JOB_LOG_DIR=gs://your-storage-bucket
    
    { python3 .local/lib/python3.10/site-packages/paxml/main.py \
        --jax_fully_async_checkpoint=1 \
        --exit_after_ondemand_checkpoint=1 \
        --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \
        --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt

    שימו לב לשני הדגלים שמועברים לפקודה:

    • jax_fully_async_checkpoint: אם הדגל הזה מופעל, נעשה שימוש ב-orbax.checkpoint.AsyncCheckpointer. הכיתה AsyncCheckpointer שומרת באופן אוטומטי נקודת ביקורת כשסקריפט האימון מקבל אות SIGTERM.
    • exit_after_ondemand_checkpoint: אם הדגל הזה מופעל, תהליך ה-TPU יוצא אחרי שנקודת הבדיקה האוטומטית נשמרת בהצלחה, מה שגורם לביצוע התחזוקה באופן מיידי. אם לא משתמשים בדגל הזה, האימון יימשך אחרי שמירת נקודת הבדיקה, ו-Cloud TPU יחכה להתרחשות זמן קצוב לתפוגה (5 דקות) לפני ביצוע התחזוקה הנדרשת.

נקודת ביקורת אוטומטית באמצעות Orbax

התכונה 'בדיקה אוטומטית' לא מוגבלת ל-MaxText או ל-Pax. כל מסגרת שיכולה לתעד את אות ה-SIGTERM ולהתחיל תהליך של יצירת נקודת ביקורת פועלת עם התשתית שמסופקת על ידי Autocheckpoint. ‫Orbax, מרחב שמות שמספק ספריות שירות נפוצות למשתמשי JAX, מספק את היכולות האלה.

כמו שמוסבר במסמכי Orbax, היכולות האלה מופעלות כברירת מחדל למשתמשי orbax.checkpoint.CheckpointManager. השיטה save שמופעלת אחרי כל שלב בודקת באופן אוטומטי אם צפוי אירוע תחזוקה. אם כן, היא שומרת נקודת ביקורת גם אם מספר השלב הוא לא כפולה של save_interval_steps. במאמרי העזרה של GitHub מוסבר גם איך לשנות את קוד המשתמש כדי שהאימון יסתיים אחרי שמירת נקודת ביקורת אוטומטית.