הרצת עומס עבודה של אצווה באמצעות Pathways

לצורך המאמר הזה, עומסי עבודה באצווה מוגדרים כעומסי עבודה של JAX שמופעלים עד לסיום ומוצבים באותו אשכול GKE כמו אשכול Pathways, במיוחד לצד רכיבי בקר Pathways (שרת proxy של IFRT ומנהל משאבים של Pathways). השלמת עומס העבודה של JAX מסיימת את רכיבי האשכול של Pathways. במדריך הזה נעשה שימוש בעומס עבודה של אימון JAX כדי להדגים את זה.

לפני שמתחילים

חשוב לוודא שיש לכם:

יצירת תמונת אימון באמצעות Maxtext

MaxText הוא פרויקט קוד פתוח של מודל שפה גדול (LLM) שפותח על ידי Google. היא כתובה ב-JAX ומתוכננת להיות בעלת ביצועים גבוהים וניתנת להרחבה, והיא פועלת ביעילות במעבדי TPU ובמעבדי GPU של Google Cloud.

כדי ליצור קובץ אימג' של MaxText Docker באמצעות הגרסה היציבה העדכנית ביותר של JAX ממאגר ה-OSS GitHub, מריצים את הפקודה הבאה:

git clone https://github.com/AI-Hypercomputer/maxtext
cd maxtext/dependencies/scripts
gcloud config set project PROJECT
bash ./docker_build_dependency_image.sh MODE=stable
gcloud auth configure-docker
bash ./docker_upload_runner.sh CLOUD_IMAGE_NAME=USER_runner # This script needs bash version >= 4.2 to execute.

הפקודה הזו מעבירה את תמונת Kubernetes של MaxText אל gcr.io/$PROJECT/${USER}_runner. אפשר להשתמש בקובץ האימג' הזה של Docker כדי להריץ אימון ב-TPU באמצעות קצה העורפי של Pathways.

הפעלת עומס עבודה של אצווה באמצעות PathwaysJob API

קובץ המניפסט הבא פורס את רכיבי Pathways ומריץ עומס עבודה של MaxText באמצעות PathwaysJob API. עומס העבודה מוכל בתוך קונטיינר main ומפעיל את train.py.

מעתיקים את קוד ה-YAML הבא לקובץ בשם pathways-job-batch-training.yaml ומעדכנים את הערכים שניתנים לעריכה.

apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USER
spec:
  maxRestarts: MAX_RESTARTS
  workers:
  - type: TPU_MACHINE_TYPE
    topology: TOPOLOGY
    numSlices: WORKLOAD_NODEPOOL_COUNT
  pathwaysDir: gs://BUCKET_NAME
  controller:
    deploymentMode: default
    template:
      spec:
        containers:
        - name: main
          image: gcr.io/PROJECT/USER_runner
          command:
          - bash
          - -c
          - |
            python3 -m MaxText.train /deps/src/MaxText/configs/base.yml \
            base_output_directory=gs://BUCKET_NAME \
            run_name=RUN_NAME \
            per_device_batch_size=1 \
            enable_checkpointing=false \
            remat_policy=full \
            global_parameter_scale=1 \
            steps=20 \
            max_target_length=2048 \
            use_iota_embed=true \
            reuse_example_batch=1 \
            dataset_type=synthetic \
            attention=flash \
            gcs_metrics=True \
            enable_single_controller=True

מחליפים את מה שכתוב בשדות הבאים:

  • USER : מזהה המשתמש ב- Google Cloud
  • MAX_RESTARTS : המספר המקסימלי של הפעמים שאפשר להפעיל מחדש את המשימה
  • TPU_MACHINE_TYPE : סוג מכונת ה-TPU
  • TOPOLOGY : הטופולוגיה של TPU v4 ואילך. מידע נוסף על גרסאות TPU וטופולוגיות נתמכות זמין במאמר גרסאות TPU.
  • WORKLOAD_NODEPOOL_COUNT : מספר מאגרי הצמתים שמשמשים את עומס העבודה של Pathways
  • BUCKET_NAME : קטגוריה של Cloud Storage לאחסון קבצים זמניים
  • PROJECT : מזהה הפרויקט ב- Google Cloud
  • RUN_NAME : שם שהמשתמש מקצה כדי לזהות את ההרצה של תהליך העבודה

אפשר לפרוס את קובץ ה-YAML‏ PathwaysJob באופן הבא:

kubectl apply -f pathways-job-batch-training.yaml

כדי להציג את מכונת PathwaysJob שנוצרה באמצעות הפקודה הקודמת, משתמשים בפקודה:

kubectl get pathwaysjob

הפלט אמור להיראות כך:

NAME             AGE
pathways-trial   9s

כדי לשנות מאפיין של מופע PathwaysJob, צריך למחוק את המופע PathwaysJob, לשנות את קובץ ה-YAML ולהחיל אותו כדי ליצור מופע PathwaysJob חדש.

כדי לעקוב אחרי התקדמות עומס העבודה, עוברים אל Logs Explorer עבור מאגר JAX על ידי בחירה באפשרות main במסנן Container Name.

אמורים להופיע יומנים כמו אלה שבהמשך, שמציינים שהאימון מתקדם. עומס העבודה יושלם אחרי 30 שלבים.

completed step: 1, seconds: 0.484, TFLOP/s/device: 87.349, Tokens/s/device: 2117.382, total_weights: 2945, loss: 10.888
completed step: 2, seconds: 0.407, TFLOP/s/device: 103.699, Tokens/s/device: 2513.735, total_weights: 3253, loss: 9.697
completed step: 3, seconds: 0.248, TFLOP/s/device: 170.300, Tokens/s/device: 4128.167, total_weights: 3154, loss: 9.641
completed step: 4, seconds: 0.216, TFLOP/s/device: 195.122, Tokens/s/device: 4729.880, total_weights: 3119, loss: 9.547
completed step: 5, seconds: 0.272, TFLOP/s/device: 155.298, Tokens/s/device: 3764.512, total_weights: 2837, loss: 10.179
completed step: 6, seconds: 0.472, TFLOP/s/device: 89.489, Tokens/s/device: 2169.266, total_weights: 3069, loss: 9.776

כדי למחוק את מופע PathwaysJob, אפשר להשתמש בפקודה הבאה:

kubectl delete -f pathways-job-batch-training.yaml

הפעלת עומס עבודה באצווה באמצעות XPK

עכשיו אפשר לשלוח את תמונת ה-Docker של Maxtext שנבנתה מראש באמצעות XPK עם אותה פקודה שבה השתמשתם קודם.

xpk workload create-pathways \
--workload=WORKLOAD \
--cluster=CLUSTER \
--num-slices=WORKLOAD_NODEPOOL_COUNT \
--tpu-type=TPU_TYPE \
--project=PROJECT \
--zone=ZONE \
--docker-image='gcr.io/PROJECT/USER_runner' \
--command="python3 -m MaxText.train /deps/src/MaxText/configs/base.yml base_output_directory=gs://BUCKET_NAME per_device_batch_size=1 enable_checkpointing=false remat_policy=full global_parameter_scale=1 steps=20 max_target_length=2048 use_iota_embed=true reuse_example_batch=1 dataset_type=synthetic attention=flash gcs_metrics=True enable_single_controller=True run_name=RUN_NAME-pathways-job"

מחליפים את מה שכתוב בשדות הבאים:

  • WORKLOAD: שם ייחודי לזיהוי עומס העבודה
  • CLUSTER: השם של אשכול GKE
  • WORKLOAD_NODEPOOL_COUNT : המספר המקסימלי של הפעמים שאפשר להפעיל מחדש את העבודה
  • TPU_TYPE: סוג ה-TPU מציין את הגרסה והגודל של Cloud TPU שרוצים ליצור. מידע נוסף על סוגי TPU נתמכים לכל גרסת TPU זמין במאמר גרסאות TPU.
  • PROJECT : מזהה הפרויקט ב- Google Cloud
  • ZONE: האזור שבו אתם מתכננים להריץ את עומס העבודה
  • USER : מזהה המשתמש ב- Google Cloud
  • RUN_NAME : שם שהמשתמש מקצה כדי לזהות את ההרצה של תהליך העבודה

הפלט אמור להיראות כך:

[XPK] Follow your Pathways workload and other resources here : https://console.cloud.google.com/logs/query;query=resource.type%3D"k8s_container"%0Aresource.labels.project_id%3D"<project-name>"%0Aresource.labels.location%3D"<your-zone>"%0Aresource.labels.cluster_name%3D"<your-cluster-name>"%0Aresource.labels.pod_name:"<your-pod-name>"%0Aseverity>%3DDEFAULT

משתמשים בקישור שמופיע בפלט של פקודת ה-XPK הקודמת כדי לעקוב אחרי התקדמות העומס. כדי לסנן את היומנים של מאגר JAX, בוחרים באפשרות jax-tpu מתחת למסנן Container Name (שם המאגר).

completed step: 1, seconds: 0.484, TFLOP/s/device: 87.349, Tokens/s/device: 2117.382, total_weights: 2945, loss: 10.888
completed step: 2, seconds: 0.407, TFLOP/s/device: 103.699, Tokens/s/device: 2513.735, total_weights: 3253, loss: 9.697
completed step: 3, seconds: 0.248, TFLOP/s/device: 170.300, Tokens/s/device: 4128.167, total_weights: 3154, loss: 9.641
completed step: 4, seconds: 0.216, TFLOP/s/device: 195.122, Tokens/s/device: 4729.880, total_weights: 3119, loss: 9.547
completed step: 5, seconds: 0.272, TFLOP/s/device: 155.298, Tokens/s/device: 3764.512, total_weights: 2837, loss: 10.179
completed step: 6, seconds: 0.472, TFLOP/s/device: 89.489, Tokens/s/device: 2169.266, total_weights: 3069, loss: 9.776

עומס העבודה מסתיים אחרי מספר השלבים שצוין. אם רוצים לסיים את התהליך לפני הזמן, משתמשים בפקודה הבאה:

xpk workload delete --workload=WORKLOAD --cluster=CLUSTER --project=PROJECT --zone=ZONE

המאמרים הבאים