התאמת עומסי עבודה של למידת מכונה באמצעות Ray

במסמך הזה מפורטות הוראות להפעלת עומסי עבודה של למידת מכונה (ML) באמצעות Ray ו-JAX ב-TPU. יש שני מצבים שונים לשימוש ב-TPU עם Ray: מצב ממוקד-מכשיר (PyTorch/XLA) ו מצב ממוקד-מארח (JAX).

המסמך הזה מיועד למשתמשים שכבר הגדירו סביבת TPU. מידע נוסף זמין במקורות המידע הבאים:

מצב שמתמקד במכשיר (PyTorch/XLA)

במצב שמתמקד במכשיר, נשמרים רוב המאפיינים של סגנון התכנות הקלאסי של PyTorch. במצב הזה, מוסיפים סוג חדש של מכשיר XLA, שפועל כמו כל מכשיר PyTorch אחר. כל תהליך בודד מתקשר עם מכשיר XLA אחד.

המצב הזה מתאים במיוחד אם אתם כבר מכירים את PyTorch עם GPUs ורוצים להשתמש בהפשטות דומות של קידוד.

בקטעים הבאים מוסבר איך להריץ עומס עבודה של PyTorch/XLA במכשיר אחד או יותר בלי להשתמש ב-Ray, ואז איך להריץ את אותו עומס עבודה בכמה מארחים באמצעות Ray.

יצירת TPU

  1. יוצרים משתני סביבה לפרמטרים של יצירת TPU.

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=europe-west4-b
    export ACCELERATOR_TYPE=v5p-8
    export RUNTIME_VERSION=v2-alpha-tpuv5

    תיאורים של משתני סביבה

    • PROJECT_ID: מזהה הפרויקט ב- Google Cloud . משתמשים בפרויקט קיים או יוצרים פרויקט חדש.
    • TPU_NAME: השם של ה-TPU.
    • ZONE: האזור שבו תיצור את מכונת ה-TPU הווירטואלית. מידע נוסף על אזורים נתמכים זמין במאמר אזורים ותחומים של TPU.
    • ACCELERATOR_TYPE: סוג המאיץ מציין את הגרסה והגודל של Cloud TPU שרוצים ליצור. מידע נוסף על סוגי המאיצים הנתמכים בכל גרסת TPU זמין במאמר בנושא גרסאות TPU.
    • RUNTIME_VERSION: גרסת התוכנה של Cloud TPU.

  2. כדי ליצור מכונה וירטואלית של TPU מדגם v5p עם 8 ליבות, משתמשים בפקודה הבאה:

    gcloud compute tpus tpu-vm create $TPU_NAME \
       --zone=$ZONE \
       --accelerator-type=$ACCELERATOR_TYPE  \
       --version=$RUNTIME_VERSION
  3. מתחברים למכונת ה-TPU הווירטואלית באמצעות הפקודה הבאה:

    gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE

אם אתם משתמשים ב-GKE, תוכלו לקרוא את המדריך KubeRay on GKE כדי לקבל מידע על הגדרות.

דרישות התקנה

מריצים את הפקודות הבאות במכונת ה-VM של TPU כדי להתקין את יחסי התלות הנדרשים:

  1. שומרים את הטקסט הבא בקובץ. לדוגמה, requirements.txt.

    --find-links https://storage.googleapis.com/libtpu-releases/index.html
    --find-links https://storage.googleapis.com/libtpu-wheels/index.html
    torch~=2.6.0
    torch_xla[tpu]~=2.6.0
    ray[default]==2.40.0
    
  2. כדי להתקין את יחסי התלות הנדרשים, מריצים את הפקודה:

    pip install -r requirements.txt
    

אם אתם מריצים את עומס העבודה ב-GKE, מומלץ ליצור Dockerfile שמתקין את התלות הנדרשת. לדוגמה, אפשר לעיין במאמר הרצת עומס העבודה בצמתי חלוקה של TPU במסמכי GKE.

הרצת עומס עבודה של PyTorch/XLA במכשיר יחיד

בדוגמה הבאה מוסבר איך ליצור טנזור XLA במכשיר יחיד, שהוא שבב TPU. זה דומה לאופן שבו PyTorch מטפל בסוגים אחרים של מכשירים.

  1. שומרים את קטע הקוד הבא בקובץ. לדוגמה, workload.py.

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    
    t = torch.randn(2, 2, device=xm.xla_device())
    print(t.device)
    print(t)
    

    הצהרת הייבוא import torch_xla מאתחלת את PyTorch/XLA, והפונקציה xm.xla_device() מחזירה את מכשיר ה-XLA הנוכחי, שבב TPU.

  2. מגדירים את משתנה הסביבה PJRT_DEVICE ל-TPU.

    export PJRT_DEVICE=TPU
    
  3. מריצים את הסקריפט.

    python workload.py
    

    הפלט אמור להיראות כך: מוודאים שהפלט מציין שמכשיר ה-XLA נמצא.

    xla:0
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    

הפעלת PyTorch/XLA בכמה מכשירים

  1. מעדכנים את קטע הקוד מהקטע הקודם כדי להפעיל אותו בכמה מכשירים.

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    
    def _mp_fn(index):
        t = torch.randn(2, 2, device=xm.xla_device())
        print(t.device)
        print(t)
    
    if __name__ == '__main__':
        torch_xla.launch(_mp_fn, args=())
    
  2. מריצים את הסקריפט.

    python workload.py
    

    אם מריצים את קטע הקוד ב-TPU v5p-8, הפלט אמור להיראות כך:

    xla:0
    xla:0
    xla:0
    tensor([[ 1.2309,  0.9896],
            [ 0.5820, -1.2950]], device='xla:0')
    xla:0
    tensor([[ 1.2309,  0.9896],
            [ 0.5820, -1.2950]], device='xla:0')
    tensor([[ 1.2309,  0.9896],
            [ 0.5820, -1.2950]], device='xla:0')
    tensor([[ 1.2309,  0.9896],
            [ 0.5820, -1.2950]], device='xla:0')
    

הפונקציה torch_xla.launch() מקבלת שני ארגומנטים: פונקציה ורשימה של פרמטרים. היא יוצרת תהליך לכל מכשיר XLA זמין ומפעילה את הפונקציה שצוינה בארגומנטים. בדוגמה הזו, יש 4 מכשירי TPU זמינים, ולכן torch_xla.launch() יוצר 4 תהליכים ומפעיל את _mp_fn() בכל מכשיר. לכל תהליך יש גישה רק למכשיר אחד, ולכן לכל מכשיר יש את האינדקס 0, והערך xla:0 מודפס לכל התהליכים.

הפעלת PyTorch/XLA בכמה מארחים באמצעות Ray

בקטעים הבאים מוצג איך להריץ את אותו קטע קוד על פרוסת TPU גדולה יותר עם כמה מארחים. מידע נוסף על ארכיטקטורת TPU מרובת מארחים זמין במאמר ארכיטקטורת מערכת.

בדוגמה הזו, הגדרתם את Ray באופן ידני. אם אתם כבר יודעים איך להגדיר את Ray, אתם יכולים לדלג לקטע האחרון, הפעלת עומס עבודה של Ray. למידע נוסף על הגדרת Ray לסביבת ייצור, אפשר לעיין במקורות המידע הבאים:

יצירת מכונת TPU וירטואלית עם כמה מארחים

  1. יוצרים משתני סביבה לפרמטרים של יצירת TPU.

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=europe-west4-b
    export ACCELERATOR_TYPE=v5p-16
    export RUNTIME_VERSION=v2-alpha-tpuv5

    תיאורים של משתני סביבה

    • PROJECT_ID: מזהה הפרויקט ב- Google Cloud . משתמשים בפרויקט קיים או יוצרים פרויקט חדש.
    • TPU_NAME: השם של ה-TPU.
    • ZONE: האזור שבו תיצור את מכונת ה-TPU הווירטואלית. מידע נוסף על אזורים נתמכים זמין במאמר אזורים ותחומים של TPU.
    • ACCELERATOR_TYPE: סוג המאיץ מציין את הגרסה והגודל של Cloud TPU שרוצים ליצור. מידע נוסף על סוגי המאיצים הנתמכים בכל גרסת TPU זמין במאמר בנושא גרסאות TPU.
    • RUNTIME_VERSION: גרסת התוכנה של Cloud TPU.

  2. כדי ליצור TPU v5p עם כמה מארחים (v5p-16, עם 4 שבבי TPU בכל מארח) מריצים את הפקודה הבאה:

    gcloud compute tpus tpu-vm create $TPU_NAME \
       --zone=$ZONE \
       --accelerator-type=$ACCELERATOR_TYPE \
       --version=$RUNTIME_VERSION

הגדרת Ray

ל-TPU v5p-16 יש 2 מארחי TPU, שלכל אחד מהם יש 4 שבבי TPU. בדוגמה הזו, תפעילו את צומת הראש של Ray במארח אחד ותוסיפו את המארח השני כצומת עובד לאשכול Ray.

  1. מתחברים למארח הראשון באמצעות SSH.

    gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE --worker=0
  2. מתקינים יחסי תלות עם אותו קובץ דרישות כמו בקטע הדרישות להתקנה.

    pip install -r requirements.txt
    
  3. מתחילים את התהליך של Ray.

    ray start --head --port=6379
    

    הפלט אמור להיראות כך:

    Enable usage stats collection? This prompt will auto-proceed in 10 seconds to avoid blocking cluster startup. Confirm [Y/n]: y
    Usage stats collection is enabled. To disable this, add `--disable-usage-stats` to the command that starts the cluster, or run the following command: `ray disable-usage-stats` before starting the cluster. See https://docs.ray.io/en/master/cluster/usage-stats.html for more details.
    
    Local node IP: 10.130.0.76
    
    --------------------
    Ray runtime started.
    --------------------
    
    Next steps
    To add another node to this Ray cluster, run
        ray start --address='10.130.0.76:6379'
    
    To connect to this Ray cluster:
        import ray
        ray.init()
    
    To terminate the Ray runtime, run
        ray stop
    
    To view the status of the cluster, use
        ray status
    

    מארח ה-TPU הזה הוא עכשיו צומת ה-head של Ray. רושמים את השורות שבהן מוסבר איך להוסיף עוד צומת לאשכול Ray, בדומה לשורות הבאות:

    To add another node to this Ray cluster, run
        ray start --address='10.130.0.76:6379'
    

    תשתמשו בפקודה הזו בשלב מאוחר יותר.

  4. בודקים את הסטטוס של אשכול Ray:

    ray status
    

    הפלט אמור להיראות כך:

    ======== Autoscaler status: 2025-01-14 22:03:39.385610 ========
    Node status
    ---------------------------------------------------------------
    Active:
    1 node_bc0c62819ddc0507462352b76cc06b462f0e7f4898a77e5133c16f79
    Pending:
    (no pending nodes)
    Recent failures:
    (no failures)
    
    Resources
    ---------------------------------------------------------------
    Usage:
    0.0/208.0 CPU
    0.0/4.0 TPU
    0.0/1.0 TPU-v5p-16-head
    0B/268.44GiB memory
    0B/119.04GiB object_store_memory
    0.0/1.0 your-tpu-name
    
    Demands:
    (no resource demands)
    

    האשכול מכיל רק 4 יחידות TPU ‏ (0.0/4.0 TPU) כי עד עכשיו הוספתם רק את צומת הראש.

    עכשיו, כשהצומת הראשי פועל, אפשר להוסיף את המארח השני לאשכול.

  5. מתחברים למארח השני באמצעות SSH.

    gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE --worker=1
  6. מתקינים את הרכיבים התלויים באמצעות אותו קובץ דרישות כמו בקטע התקנת דרישות.

    pip install -r requirements.txt
    
  7. מתחילים את התהליך של Ray. כדי להוסיף את הצומת הזה לאשכול Ray הקיים, משתמשים בפקודה מהפלט של הפקודה ray start. חשוב להחליף את כתובת ה-IP והיציאה בפקודה הבאה:

    ray start --address='10.130.0.76:6379'

    הפלט אמור להיראות כך:

    Local node IP: 10.130.0.80
    [2025-01-14 22:30:07,397 W 75572 75572] global_state_accessor.cc:463: Retrying to get node with node ID 35f9ac0675c91429805cdc1b97c3713422d97eee783ccb0c0304f5c1
    
    --------------------
    Ray runtime started.
    --------------------
    
    To terminate the Ray runtime, run
    ray stop
    
  8. בודקים שוב את הסטטוס של Ray:

    ray status
    

    הפלט אמור להיראות כך:

    ======== Autoscaler status: 2025-01-14 22:45:21.485617 ========
    Node status
    ---------------------------------------------------------------
    Active:
    1 node_bc0c62819ddc0507462352b76cc06b462f0e7f4898a77e5133c16f79
    1 node_35f9ac0675c91429805cdc1b97c3713422d97eee783ccb0c0304f5c1
    Pending:
    (no pending nodes)
    Recent failures:
    (no failures)
    
    Resources
    ---------------------------------------------------------------
    Usage:
    0.0/416.0 CPU
    0.0/8.0 TPU
    0.0/1.0 TPU-v5p-16-head
    0B/546.83GiB memory
    0B/238.35GiB object_store_memory
    0.0/2.0 your-tpu-name
    
    Demands:
    (no resource demands)
    

    מארח ה-TPU השני הוא עכשיו צומת באשכול. ברשימת המשאבים הזמינים מוצגים עכשיו 8 יחידות TPU ‏ (0.0/8.0 TPU).

הרצת עומס עבודה של Ray

  1. מעדכנים את קטע הקוד כדי להריץ אותו באשכול Ray:

    import os
    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    import ray
    
    import torch.distributed as dist
    import torch_xla.runtime as xr
    from torch_xla._internal import pjrt
    
    # Defines the local PJRT world size, the number of processes per host.
    LOCAL_WORLD_SIZE = 4
    # Defines the number of hosts in the Ray cluster.
    NUM_OF_HOSTS = 4
    GLOBAL_WORLD_SIZE = LOCAL_WORLD_SIZE * NUM_OF_HOSTS
    
    def init_env():
        local_rank = int(os.environ['TPU_VISIBLE_CHIPS'])
    
        pjrt.initialize_multiprocess(local_rank, LOCAL_WORLD_SIZE)
        xr._init_world_size_ordinal()
    
    # This decorator signals to Ray that the `print_tensor()` function should be run on a single TPU chip.
    @ray.remote(resources={"TPU": 1})
    def print_tensor():
        # Initializes the runtime environment on each Ray worker. Equivalent to
        # the `torch_xla.launch call` in the Run PyTorch/XLA on multiple devices section.
        init_env()
    
        t = torch.randn(2, 2, device=xm.xla_device())
        print(t.device)
        print(t)
    
    ray.init()
    
    # Uses Ray to dispatch the function call across available nodes in the cluster.
    tasks = [print_tensor.remote() for _ in range(GLOBAL_WORLD_SIZE)]
    ray.get(tasks)
    
    ray.shutdown()
    
  2. מריצים את הסקריפט בצומת הראשי של Ray. מחליפים את ray-workload.py בנתיב לסקריפט.

    python ray-workload.py

    הפלט אמור להיראות כך:

    WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
    xla:0
    xla:0
    xla:0
    xla:0
    xla:0
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    xla:0
    xla:0
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    xla:0
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    

    הפלט מציין שהפונקציה הופעלה בהצלחה בכל מכשיר XLA (8 מכשירים בדוגמה הזו) בפלח TPU מרובה מארחים.

מצב שמתמקד במארח (JAX)

בקטעים הבאים מתואר מצב שמתמקד במארח עם JAX. ‫JAX משתמש בפרדיגמה של תכנות פונקציונלי ותומך בסמנטיקה של תוכנית יחידה ברמה גבוהה, נתונים מרובים (SPMD). במקום שכל תהליך יבצע אינטראקציה עם מכשיר XLA יחיד, קוד JAX מיועד לפעול בכמה מכשירים במארח יחיד בו-זמנית.

‫JAX מיועד למחשוב עתיר ביצועים (HPC) ויכול להשתמש ביעילות ב-TPU לאימון ולהסקת מסקנות בקנה מידה גדול. השימוש במצב הזה מומלץ אם אתם מכירים את המושגים של תכנות פונקציונלי, כדי שתוכלו לנצל את מלוא הפוטנציאל של JAX.

ההוראות האלה מניחות שכבר הגדרתם סביבת Ray ו-TPU, כולל סביבת תוכנה שכוללת את JAX וחבילות קשורות אחרות. כדי ליצור אשכול Ray TPU, פועלים לפי ההוראות במאמר הפעלת אשכול GKE עם יחידות TPU ל-KubeRay.Google Cloud מידע נוסף על שימוש במעבדי TPU עם KubeRay זמין במאמר שימוש במעבדי TPU עם KubeRay.

הרצת עומס עבודה של JAX ב-TPU עם מארח יחיד

סקריפט הדוגמה הבא מראה איך להריץ פונקציית JAX באשכול Ray עם TPU של מארח יחיד, כמו v6e-4. אם יש לכם TPU עם כמה מארחים, הסקריפט הזה מפסיק להגיב בגלל מודל ההפעלה של JAX עם כמה בקרי. מידע נוסף על הרצת Ray ב-TPU עם כמה מארחים זמין במאמר הרצת עומס עבודה של JAX ב-TPU עם כמה מארחים.

  1. יוצרים משתני סביבה לפרמטרים של יצירת TPU.

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=europe-west4-a
    export ACCELERATOR_TYPE=v6e-4
    export RUNTIME_VERSION=v2-alpha-tpuv6e

    תיאורים של משתני סביבה

    • PROJECT_ID: מזהה הפרויקט ב- Google Cloud . משתמשים בפרויקט קיים או יוצרים פרויקט חדש.
    • TPU_NAME: השם של ה-TPU.
    • ZONE: האזור שבו תיצור את מכונת ה-TPU הווירטואלית. מידע נוסף על אזורים נתמכים זמין במאמר אזורים ותחומים של TPU.
    • ACCELERATOR_TYPE: סוג המאיץ מציין את הגרסה והגודל של Cloud TPU שרוצים ליצור. מידע נוסף על סוגי המאיצים הנתמכים בכל גרסת TPU זמין במאמר בנושא גרסאות TPU.
    • RUNTIME_VERSION: גרסת התוכנה של Cloud TPU.

  2. כדי ליצור מכונה וירטואלית של TPU מדור v6e עם 4 ליבות, משתמשים בפקודה הבאה:

    gcloud compute tpus tpu-vm create $TPU_NAME \
       --zone=$ZONE \
       --accelerator-type=$ACCELERATOR_TYPE  \
       --version=$RUNTIME_VERSION
  3. מתחברים למכונת ה-TPU הווירטואלית באמצעות הפקודה הבאה:

    gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE
  4. מתקינים את JAX ואת Ray ב-TPU.

    pip install ray jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    
  5. שומרים את הקוד הבא בקובץ. לדוגמה, ray-jax-single-host.py.

    import ray
    import jax
    
    @ray.remote(resources={"TPU": 4})
    def my_function() -> int:
        return jax.device_count()
    
    h = my_function.remote()
    print(ray.get(h)) # => 4
    

    אם אתם רגילים להריץ את Ray עם GPUs, יש כמה הבדלים חשובים כשמשתמשים ב-TPUs:

    • במקום להגדיר את num_gpus, מציינים את TPU כמשאב בהתאמה אישית ומגדירים את מספר שבבי ה-TPU.
    • מציינים את ה-TPU באמצעות מספר הצ'יפים לכל צומת עובד של Ray. לדוגמה, אם אתם משתמשים ב-v6e-4 ומריצים פונקציה מרוחקת עם TPU שמוגדר ל-4, המארח של ה-TPU צורך את כל הזיכרון.
    • זה שונה מהאופן שבו מעבדי GPU פועלים בדרך כלל, עם תהליך אחד לכל מארח. לא מומלץ להגדיר את TPU למספר שאינו 4.
      • חריג: אם יש לכם מארח יחיד v6e-8 או v5litepod-8, צריך להגדיר את הערך הזה כ-8.
  6. מריצים את הסקריפט.

    python ray-jax-single-host.py

הפעלת עומס עבודה של JAX ב-TPU עם כמה מארחים

בדוגמה הבאה מוצג סקריפט שממחיש איך להריץ פונקציית JAX באשכול Ray עם TPU מרובה מארחים. בדוגמה של הסקריפט נעשה שימוש ב-v6e-16.

  1. יוצרים משתני סביבה לפרמטרים של יצירת TPU.

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=europe-west4-a
    export ACCELERATOR_TYPE=v6e-16
    export RUNTIME_VERSION=v2-alpha-tpuv6e

    תיאורים של משתני סביבה

    • PROJECT_ID: מזהה הפרויקט ב- Google Cloud . משתמשים בפרויקט קיים או יוצרים פרויקט חדש.
    • TPU_NAME: השם של ה-TPU.
    • ZONE: האזור שבו תיצור את מכונת ה-TPU הווירטואלית. מידע נוסף על אזורים נתמכים זמין במאמר אזורים ותחומים של TPU.
    • ACCELERATOR_TYPE: סוג המאיץ מציין את הגרסה והגודל של Cloud TPU שרוצים ליצור. מידע נוסף על סוגי המאיצים הנתמכים בכל גרסת TPU זמין במאמר בנושא גרסאות TPU.
    • RUNTIME_VERSION: גרסת התוכנה של Cloud TPU.

  2. כדי ליצור מכונה וירטואלית של TPU מדגם v6e עם 16 ליבות, משתמשים בפקודה הבאה:

    gcloud compute tpus tpu-vm create $TPU_NAME \
       --zone=$ZONE \
       --accelerator-type=$ACCELERATOR_TYPE  \
       --version=$RUNTIME_VERSION
  3. מתקינים את JAX ואת Ray בכל עובדי ה-TPU.

    gcloud compute tpus tpu-vm ssh $TPU_NAME \
       --zone=$ZONE \
       --worker=all \
       --command="pip install ray jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"
  4. שומרים את הקוד הבא בקובץ. לדוגמה, ray-jax-multi-host.py.

    import ray
    import jax
    
    @ray.remote(resources={"TPU": 4})
    def my_function() -> int:
        return jax.device_count()
    
    ray.init()
    num_tpus = ray.available_resources()["TPU"]
    num_hosts = int(num_tpus) # 4
    h = [my_function.remote() for _ in range(num_hosts)]
    print(ray.get(h)) # [16, 16, 16, 16]
    

    אם אתם רגילים להריץ את Ray עם GPUs, יש כמה הבדלים חשובים כשמשתמשים ב-TPUs:

    • בדומה לעומסי עבודה של PyTorch ב-GPU:
    • בניגוד לעומסי עבודה של PyTorch במעבדי GPU, ל-JAX יש תצוגה גלובלית של המכשירים הזמינים באשכול.
  5. מעתיקים את הסקריפט לכל עובדי ה-TPU.

    gcloud compute tpus tpu-vm scp ray-jax-multi-host.py $TPU_NAME: --zone=$ZONE --worker=all
  6. מריצים את הסקריפט.

    gcloud compute tpus tpu-vm ssh $TPU_NAME \
       --zone=$ZONE \
       --worker=all \
       --command="python ray-jax-multi-host.py"

הפעלת עומס עבודה של JAX עם ריבוי-פרוסות

ריבוי-פרוסות (Multislice) מאפשרת להריץ עומסי עבודה בכמה פרוסות TPU בתוך TPU Pod יחיד או בכמה אשכולות ברשת של מרכז הנתונים.

אפשר להשתמש בחבילה ray-tpu כדי לפשט את האינטראקציות של Ray עם חלקי TPU.

התקנת ray-tpu באמצעות pip.

pip install ray-tpu

מידע נוסף על השימוש בחבילה ray-tpu זמין במאמר תחילת העבודה במאגר GitHub. דוגמה לשימוש ב-Multislice מופיעה במאמר בנושא הרצה ב-Multislice.

תזמור עומסי עבודה באמצעות Ray ו-MaxText

מידע נוסף על השימוש ב-Ray עם MaxText זמין במאמר הפעלת משימת אימון באמצעות MaxText.

משאבי TPU ו-Ray

‫Ray מתייחס ל-TPU באופן שונה מ-GPU כדי להתאים להבדלים בשימוש. בדוגמה הבאה, יש בסך הכול תשעה צמתי Ray:

  • צומת ה-head של Ray פועל במכונה וירטואלית n1-standard-16.
  • צמתי העובדים של Ray פועלים על שני v6e-16 TPU. כל TPU כולל ארבעה עובדים.
$ ray status
======== Autoscaler status: 2024-10-17 09:30:00.854415 ========
Node status
---------------------------------------------------------------
Active:
 1 node_e54a65b81456cee40fcab16ce7b96f85406637eeb314517d9572dab2
 1 node_9a8931136f8d2ab905b07d23375768f41f27cc42f348e9f228dcb1a2
 1 node_c865cf8c0f7d03d4d6cae12781c68a840e113c6c9b8e26daeac23d63
 1 node_435b1f8f1fbcd6a4649c09690915b692a5bac468598e9049a2fac9f1
 1 node_3ed19176e9ecc2ac240c818eeb3bd4888fbc0812afebabd2d32f0a91
 1 node_6a88fe1b74f252a332b08da229781c3c62d8bf00a5ec2b90c0d9b867
 1 node_5ead13d0d60befd3a7081ef8b03ca0920834e5c25c376822b6307393
 1 node_b93cb79c06943c1beb155d421bbd895e161ba13bccf32128a9be901a
 1 node_9072795b8604ead901c5268ffcc8cc8602c662116ac0a0272a7c4e04
Pending:
 (no pending nodes)
Recent failures:
 (no failures)

Resources
---------------------------------------------------------------
Usage:
 0.0/727.0 CPU
 0.0/32.0 TPU
 0.0/2.0 TPU-v6e-16-head
 0B/5.13TiB memory
 0B/1.47TiB object_store_memory
 0.0/4.0 tpu-group-0
 0.0/4.0 tpu-group-1

Demands:
 (no resource demands)

תיאורי השדות בדוח 'שימוש במשאבים':

  • CPU: המספר הכולל של מעבדי ה-CPU שזמינים באשכול.
  • TPU: מספר שבבי ה-TPU באשכול.
  • TPU-v6e-16-head: מזהה מיוחד של המשאב שמתאים לעובד 0 בפרוסת TPU. הפעולה הזו חשובה כדי לגשת לפרוסות TPU נפרדות.
  • memory: זיכרון הערימה של העובד שבו האפליקציה משתמשת.
  • object_store_memory: הזיכרון שנעשה בו שימוש כשהאפליקציה יוצרת אובייקטים במאגר האובייקטים באמצעות ray.put, וכשהיא מחזירה ערכים מפונקציות מרוחקות.
  • tpu-group-0 ו-tpu-group-1: מזהים ייחודיים של חלקי ה-TPU. השלב הזה חשוב להרצת משימות בפלחים. השדות האלה מוגדרים ל-4 כי יש 4 מארחים לכל פרוסת TPU ב-v6e-16.