אימון מרובה פרוסות ואימון אלסטי ב-TPU באמצעות Ray Train ב-GKE

במדריך הזה נסביר איך לאמן מודלי שפה גדולים (LLM) כמו Llama 3 70B ב-Google Kubernetes Engine ‏ (GKE) באמצעות MaxText,‏ Ray Train ו-Multislice Trillium TPUs. במדריך הזה מוסבר איך להגדיר את הרשת של מרכז הנתונים המשני הנדרש, לשלוח עומס עבודה של אימון מבוזר ולהריץ אותו בהצלחה ב-32 שבבי TPU פיזיים.

המדריך הזה מיועד לאדמינים של פלטפורמות, למפעילים ולמומחי AI שרוצים ללמוד איך להתמודד עם אתגרי הזיכרון והרשת באימון של מודלים עם 70 מיליארד פרמטרים בפרוסות TPU מבוזרות עם כמה מארחים.

רקע

השילוב של GKE,‏ KubeRay,‏ MaxText ו-TPU מספק פלטפורמה חזקה וניתנת להרחבה לאימון מודלים בקנה מידה גדול. בקטע הזה מתוארות הטכנולוגיות העיקריות שמופיעות במדריך הזה:

JAX

JAX היא ספריית Python לחישוב מערכים ולהמרת תוכניות שמותאמות למאיצים. הספרייה משתמשת בקומפיילר XLA כדי ליצור קוד שעבר אופטימיזציה גבוהה וניתן להרחבה ביעילות במאיצים.

MaxText

MaxText הוא פלטפורמת LLM בקוד פתוח עם ביצועים גבוהים, שנועדה להתאמה אישית ולשינוי קנה מידה. ‫MaxText מבוסס על JAX ועבר אופטימיזציה כדי לפעול ביעילות ב-Cloud TPU.

TPUs

יחידות לעיבוד טנסורים (TPU) הן מאיצים שנוצרו על ידי Google בהתאמה אישית כדי לבצע אופטימיזציה של עומסי עבודה של למידת מכונה. בניגוד למעבדי CPU לשימוש כללי או למעבדי GPU לעיבוד מקבילי, מעבדי TPU מותאמים במיוחד לחישובים של מטריצות וטנסורים, שהם הבסיס ללמידה עמוקה, ולכן הם יעילים במשימה הספציפית הזו. היתרון העיקרי של יחידות TPU הוא הביצועים בקנה מידה נרחב.

במדריך הזה נעשה שימוש ב-TPU Trillium, הדור השישי של TPUs, בתבנית פריסה של Multislice. ב-Cloud TPU Multislice, שני חלקי Cloud TPU או יותר מתקשרים דרך רשת מרכז הנתונים (DCN). ‫Multislice מאפשר אימון מקיף, חסכוני ורחב היקף עם הרחבה כמעט לינארית של עד עשרות אלפי שבבי TPU. מידע נוסף על Multislice זמין במאמר סקירה כללית על Multislice ב-Cloud TPU.

KubeRay

KubeRay הוא אופרטור של Kubernetes שמספק דרך מאוחדת לפריסה, לניהול ולניטור של אפליקציות Ray ב-Kubernetes. אפשר להתקין את האופרטור KubeRay ולנהל אותו באמצעות התוסף Ray ב-GKE. זו הדרך המומלצת לפרוס ולנהל אשכולות Ray ב-GKE.

GKE Dynamic Resource Allocation Network (DRANET)

GKE DRANET (רשת להקצאת משאבים דינמית) היא תכונה שמקצה באופן דינמי מכשירי רשת עם ביצועים גבוהים ל-Pods, תוך עקיפת הרשת הרגילה של Kubernetes, ומאפשרת ביצועים גבוהים ברשת DCN.

מטרות

במדריך הזה מוסבר איך:

  1. מגדירים אשכול GKE עם שני מאגרי צמתים של TPU מרובי-מארחים.
  2. הגדרת DCN משני לתקשורת TPU בין חלקי רשת.
  3. מגדירים את KubeRay לניהול סביבת האימון המבוזרת.
  4. פורסים משאב בהתאמה אישית מסוג RayCluster באמצעות הקצאת משאבים דינמית (DRA) לצירופי רשת.
  5. יוצרים סקריפט אימון ב-Python באמצעות JaxTrainer של Ray Train כדי לתזמן את לולאת האימון של MaxText בכל חלקי ה-TPU.
  6. מריצים אימון בסיסי של Llama 3 8B.
  7. הגדלת קנה מידה עד Llama 3 70B באמצעות חלוקה ל-2D (Tensor Parallelism ו-FSDP) ב-DCN.

לפני שמתחילים

  • נכנסים לחשבון Google Cloud . אם אתם משתמשים חדשים ב- Google Cloud, צרו חשבון כדי שתוכלו להעריך את הביצועים של המוצרים שלנו בתרחישים מהעולם האמיתי. לקוחות חדשים מקבלים בחינם גם קרדיט בשווי 300$ להרצה, לבדיקה ולפריסה של עומסי העבודה.
  • התקינו את ה-CLI של Google Cloud.

  • אם אתם משתמשים בספק זהויות חיצוני (IdP), קודם אתם צריכים להיכנס ל-CLI של gcloud באמצעות המאגר המאוחד לניהול זהויות.

  • כדי לאתחל את ה-CLI של gcloud, הריצו את הפקודה הבאה:

    gcloud init
  • יוצרים או בוחרים Google Cloud פרויקט.

    תפקידים שנדרשים כדי לבחור או ליצור פרויקט

    • Select a project: כדי לבחור פרויקט לא צריך תפקיד IAM ספציפי – אפשר לבחור כל פרויקט שקיבלתם בו תפקיד.
    • יצירת פרויקט: כדי ליצור פרויקט, צריך את התפקיד Project Creator (יצירת פרויקטים) (roles/resourcemanager.projectCreator), שכולל את ההרשאה resourcemanager.projects.create. איך מקצים תפקידים
    • יוצרים Google Cloud פרויקט:

      gcloud projects create PROJECT_ID

      מחליפים את PROJECT_ID בשם של פרויקט Google Cloud שיוצרים.

    • בוחרים את הפרויקט שיצרתם: Google Cloud

      gcloud config set project PROJECT_ID

      מחליפים את PROJECT_ID בשם הפרויקט ב- Google Cloud .

  • מוודאים שהחיוב מופעל בפרויקט Google Cloud .

  • מפעילים את ממשקי ה-API הנדרשים:

    תפקידים שנדרשים להפעלת ממשקי API

    כדי להפעיל ממשקי API, צריך את תפקיד ה-IAM 'אדמין של Service Usage' (roles/serviceusage.serviceUsageAdmin), שכולל את ההרשאה serviceusage.services.enable. איך מקצים תפקידים

    gcloud services enable container.googleapis.com cloudbuild.googleapis.com
  • התקינו את ה-CLI של Google Cloud.

  • אם אתם משתמשים בספק זהויות חיצוני (IdP), קודם אתם צריכים להיכנס ל-CLI של gcloud באמצעות המאגר המאוחד לניהול זהויות.

  • כדי לאתחל את ה-CLI של gcloud, הריצו את הפקודה הבאה:

    gcloud init
  • יוצרים או בוחרים Google Cloud פרויקט.

    תפקידים שנדרשים כדי לבחור או ליצור פרויקט

    • Select a project: כדי לבחור פרויקט לא צריך תפקיד IAM ספציפי – אפשר לבחור כל פרויקט שקיבלתם בו תפקיד.
    • יצירת פרויקט: כדי ליצור פרויקט, צריך את התפקיד Project Creator (יצירת פרויקטים) (roles/resourcemanager.projectCreator), שכולל את ההרשאה resourcemanager.projects.create. איך מקצים תפקידים
    • יוצרים Google Cloud פרויקט:

      gcloud projects create PROJECT_ID

      מחליפים את PROJECT_ID בשם של פרויקט Google Cloud שיוצרים.

    • בוחרים את הפרויקט שיצרתם: Google Cloud

      gcloud config set project PROJECT_ID

      מחליפים את PROJECT_ID בשם הפרויקט ב- Google Cloud .

  • מוודאים שהחיוב מופעל בפרויקט Google Cloud .

  • מפעילים את ממשקי ה-API הנדרשים:

    תפקידים שנדרשים להפעלת ממשקי API

    כדי להפעיל ממשקי API, צריך את תפקיד ה-IAM 'אדמין של Service Usage' (roles/serviceusage.serviceUsageAdmin), שכולל את ההרשאה serviceusage.services.enable. איך מקצים תפקידים

    gcloud services enable container.googleapis.com cloudbuild.googleapis.com
  • מעניקים תפקידים לחשבון המשתמש. מריצים את הפקודה הבאה לכל אחד מהתפקידים הבאים ב-IAM: roles/container.admin, roles/iam.serviceAccountAdmin, roles/cloudbuild.builds.editor

    gcloud projects add-iam-policy-binding PROJECT_ID --member="user:USER_IDENTIFIER" --role=ROLE

    מחליפים את מה שכתוב בשדות הבאים:

    • PROJECT_ID: מזהה הפרויקט.
    • USER_IDENTIFIER: המזהה של חשבון המשתמש . לדוגמה, myemail@example.com.
    • ROLE: תפקיד ה-IAM שאתם מקצים לחשבון המשתמש.
  • במדריך הזה נעשה שימוש ב-TPU Trillium ‏ (v6e), לכן צריך לבחור אזור או אזור זמין. מידע נוסף זמין במאמר בנושא מכסות של Cloud TPU.

הכנת הסביבה

במדריך הזה משתמשים ב-Cloud Shell. ב-Cloud Shell מותקנים מראש כלי שורת הפקודה gcloud, helm ו-kubectl, שבהם נעשה שימוש במדריך הזה.

  1. עוברים אל Google Cloud המסוף.

  2. בחלק העליון של חלון המסוף, לוחצים על הלחצן Activate Cloud Shell כפתור הפעלת Shell. Google Cloud

    בחלק התחתון של המסוף ייפתח סשן של Cloud Shell בתוך מסגרת חדשה ותופיע הודעה של שורת הפקודה.Google Cloud

  3. במסוף, משכפלים את המאגר kubernetes-engine-samples:

    git clone https://github.com/GoogleCloudPlatform/kubernetes-engine-samples.git
    
  4. עוברים לספרייה שמכילה את הקבצים לדוגמה:

    cd kubernetes-engine-samples/ai-ml/gke-ray/raytrain/maxtext
    
  5. יוצרים ומפעילים סביבה וירטואלית של Python:

    python3 -m venv ray-env
    source ray-env/bin/activate
    
  6. מתקינים את Ray CLI:

    pip install "ray[default]==2.55.0"
    
  7. מגדירים את משתני הסביבה הבאים:

    export PROJECT_ID=$(gcloud config get project)
    export PROJECT_NUMBER=$(gcloud projects describe ${PROJECT_ID} --format="value(projectNumber)")
    export GS_BUCKET=GS_BUCKET
    export KSA_NAME=KSA_NAME
    export NAMESPACE=default
    export CLUSTER_NAME=CLUSTER_NAME
    export REGION=REGION
    export ZONE=ZONE
    export CLUSTER_VERSION=1.35.2-gke.1842000
    

    מחליפים את מה שכתוב בשדות הבאים:

    • GS_BUCKET: שם הקטגוריה ב-Cloud Storage.
    • KSA_NAME: השם של חשבון השירות של Kubernetes.
    • CLUSTER_NAME: השם של האשכול החדש.
    • REGION: האזור שבו קיבולת TPU Trillium זמינה.
    • ZONE: האזור שבו קיבולת TPU Trillium זמינה. מידע נוסף זמין במאמר זמינות של TPU ב-GKE.

הגדרת רשת אשכולות ל-Cloud TPU Multislice

בפרוסת TPU מרובת מארחים, מכשירי ה-TPU מתקשרים באמצעות חיבורים מהירים בין הצ'יפים. עם זאת, כשמריצים משימות Multislice, פרוסות ה-TPU צריכות לתקשר ביניהן דרך ה-DCN. רשתות Pod רגילות של Kubernetes עלולות ליצור צוואר בקבוק בתעבורה הזו. סוג המכונה ct6e-standard-4t מגובה בכמה כרטיסי ממשק רשת (NIC) פיזיים. כדי להשיג את הביצועים הכי טובים, יוצרים שתי רשתות VPC נוספות ומשתמשים ב-GKE DRANET כדי לחבר אותן ישירות ל-Ray Pods.

  1. יוצרים את שתי רשתות ה-VPC הנוספות עם יחידת אימון מקסימלית (MTU) גדולה:

    gcloud compute networks create ${CLUSTER_NAME}-net-1 \
      --subnet-mode=custom \
      --mtu=8896
    
    gcloud compute networks create ${CLUSTER_NAME}-net-2 \
      --subnet-mode=custom \
      --mtu=8896
    
  2. יוצרים את רשתות המשנה הייעודיות:

    gcloud compute networks subnets create tpu-subnet-1 \
      --network=${CLUSTER_NAME}-net-1 \
      --region=${REGION} \
      --range=10.50.0.0/16
    
    gcloud compute networks subnets create tpu-subnet-2 \
      --network=${CLUSTER_NAME}-net-2 \
      --region=${REGION} \
      --range=10.60.0.0/16
    

יצירת אשכול GKE

אפשר להגדיר את KubeRay ב-TPU באשכול GKE Autopilot או באשכול רגיל. מומלץ להשתמש באשכול Autopilot כדי ליהנות מחוויית Kubernetes מנוהלת באופן מלא. כדי לבחור את מצב הפעולה של GKE שהכי מתאים לעומסי העבודה שלכם, אפשר לעיין במאמר מידע על מצבי הפעולה של GKE.

כדי להשתמש ב-DRANET מנוהל של GKE, האשכול צריך להיות בגרסה 1.35.2-gke.1842000 ואילך במצב Autopilot, או בגרסה 1.34.1-gke.1829001 ואילך במצב Standard. במדריך הזה נעשה שימוש בגרסה 1.35.2-gke.1842000.

טייס אוטומטי

  1. ב-Cloud Shell, מריצים את הפקודה הבאה:

    gcloud container clusters create-auto $CLUSTER_NAME \
        --enable-ray-operator \
        --machine-type=n1-standard-16 \
        --location=$REGION \
        --cluster-version=${CLUSTER_VERSION}
    
  2. כדי לתקשר עם האשכול, צריך להגדיר את kubectl :

    gcloud container clusters get-credentials CLUSTER_NAME \
        --location=$REGION
    

רגילה

  1. ב-Cloud Shell, יוצרים אשכול Standard שמופעל בו התוסף Ray operator באמצעות הפקודה הבאה:

    gcloud container clusters create $CLUSTER_NAME \
        --addons=RayOperator,GcsFuseCsiDriver \
        --machine-type=n1-standard-16 \
        --enable-dataplane-v2 \
        --workload-pool=$PROJECT_ID.svc.id.goog \
        --location=$ZONE \
        --cluster-version=${CLUSTER_VERSION}
    

    הפקודה הזו גם מפעילה את GcsFuseCsiDriver, שמאפשר ל-Pods לטעון קטגוריות של Cloud Storage כמערכות קבצים מקומיות. יצירת האשכול עשויה להימשך כמה דקות.

  2. כדי לתקשר עם האשכול, צריך להגדיר את kubectl:

    gcloud container clusters get-credentials CLUSTER_NAME \
        --location=$ZONE
    
  3. יוצרים את המאגר הראשון של צמתים מסוג TPU slice עם מארחים מרובים, עם הפעלת GKE DRANET:

    gcloud container node-pools create v6e-16-0 \
        --location=$ZONE \
        --cluster=$CLUSTER_NAME \
        --machine-type=ct6e-standard-4t \
        --threads-per-core=1 \
        --tpu-topology=4x4 \
        --num-nodes=4 \
        --additional-node-network=network=${CLUSTER_NAME}-net-1,subnetwork=tpu-subnet-1 \
        --additional-node-network=network=${CLUSTER_NAME}-net-2,subnetwork=tpu-subnet-2 \
        --node-labels=cloud.google.com/gke-networking-dra-driver=true \
        --enable-gvnic \
        --scopes=https://www.googleapis.com/auth/cloud-platform
    
  4. יוצרים את מאגר הצמתים השני של פרוסת ה-TPU:

    gcloud container node-pools create v6e-16-1 \
        --location=$ZONE \
        --cluster=$CLUSTER_NAME \
        --machine-type=ct6e-standard-4t \
        --threads-per-core=1 \
        --tpu-topology=4x4 \
        --num-nodes=4 \
        --additional-node-network=network=${CLUSTER_NAME}-net-1,subnetwork=tpu-subnet-1 \
        --additional-node-network=network=${CLUSTER_NAME}-net-2,subnetwork=tpu-subnet-2 \
        --node-labels=cloud.google.com/gke-networking-dra-driver=true \
        --enable-gvnic \
        --scopes=https://www.googleapis.com/auth/cloud-platform
    

‫GKE מקצה מאגר צמתים שמורכב מארבע מכונות וירטואליות של TPU Trillium ‏ (v6e), שמוגדרות יחד כפרוסת TPU מרובת מארחים עם טופולוגיה של 4x4. מאגר הצמתים הזה מוכן לעומסי עבודה של אימון מבוזר.

באשכול GKE שמופעל בו Ray operator, המערכת מתקינה אוטומטית את KubeRay ואת KubeRay TPU webhook באשכול.

הגדרת קטגוריה של Cloud Storage וחשבון שירות

  1. יוצרים קטגוריה של Cloud Storage לנקודות ביקורת משותפות בין צמתי ה-TPU עם כמה מארחים.

    gsutil mb -p ${PROJECT_ID} -c STANDARD -l ${REGION} gs://${GS_BUCKET}
    
  2. כדי לאפשר גישה לקטגוריה של Cloud Storage, יוצרים חשבון שירות של Kubernetes:

    kubectl create serviceaccount ${KSA_NAME} --namespace ${NAMESPACE}
    
  3. כדי להעניק גישה לקטגוריה של Cloud Storage, מוסיפים לחשבון השירות את קישורי המדיניות הנדרשים ב-IAM:

    gcloud storage buckets add-iam-policy-binding gs://${GS_BUCKET} \
        --member "principal://iam.googleapis.com/projects/${PROJECT_NUMBER}/locations/global/workloadIdentityPools/${PROJECT_ID}.svc.id.goog/subject/ns/${NAMESPACE}/sa/${KSA_NAME}" \
        --role "roles/storage.objectUser"
    

יצירת סקריפט ההדרכה

סקריפט maxtext_multi_slice_trainer.py משתמש ב-JaxTrainer של Ray Train כדי להריץ משימת אימון מבוזרת של MaxText בשני חלקי TPU. הסקריפט מגדיר את סביבת האימון עבור שמונה עובדי TPU מרובי-מארחים ומריץ את משימת האימון של MaxText בכל צומת עובד. הפונקציה train_loop_per_worker עוטפת את נקודת הכניסה הראשית של MaxText, ומשתמשת במתזמן המבוזר של Ray כדי להריץ את מאמן MaxText בפרוסת TPU מרובת מארחים:

import os
from absl import app
import logging
from typing import Sequence
import ray
from ray.train.v2.api.config import ScalingConfig, RunConfig
from ray.train.v2.jax import JaxTrainer

def train_loop_per_worker(config):
    import maxtext
    from maxtext.trainers.pre_train.train import main as maxtext_main

    argv = config["argv"]
    maxtext_main(argv)

def main(argv: Sequence[str]):
    # Convert the config file path to an absolute path
    argv = list(argv)
    if len(argv) > 1:
        argv[1] = os.path.abspath(argv[1])

    trainer = JaxTrainer(
        train_loop_per_worker=train_loop_per_worker,
        train_loop_config={"argv": argv},
        scaling_config=ScalingConfig(
            use_tpu=True,
            num_workers=8,
            topology="4x4",
            accelerator_type="TPU-V6E",
            resources_per_worker={"TPU": 4},
            placement_strategy="SPREAD",
        ),
        run_config=RunConfig(
            name="maxtext_jaxtrainer",
            worker_runtime_env={
                "uv": {
                    # maxtext requires some additional deps
                    "packages": ["maxtext[tpu]==0.2.1"],
                    "uv_pip_install_options": ["--resolution=lowest"]
                },
            },
        ),
    )
    result = trainer.fit()
    logging.info("Training complete!")
    ray.shutdown()

if __name__ == "__main__":
    app.run(main)

הסקריפט הקודם מגדיר מופע JaxTrainer שמבקש שמונה עובדים וטופולוגיה של 4x4. באופן פנימי, Ray מספק SlicePlacementGroup בשני חלקי ה-TPU, ועוזר לוודא שעובדי Ray Train פועלים באופן אטומי בשני החלקים, עם עובד אחד לכל מארח.

אימון המודל

  1. קובץ המניפסט ray-cluster.tpu-multi-slice.yaml בספרייה הנוכחית מגדיר את המשאב המותאם אישית RayCluster. קובץ המניפסט הזה כולל את DRANET ResourceClaimTemplate כדי להקצות את מכשירי הרשת ל-GKE DRANET ול-Multislice:

    apiVersion: resource.k8s.io/v1
    kind: ResourceClaimTemplate
    metadata:
      name: two-netdev
    spec:
      spec:
        devices:
          requests:
          - name: req-netdev
            exactly:
              deviceClassName: netdev.google.com
              allocationMode: ExactCount
              count: 2
    ---
    apiVersion: ray.io/v1
    kind: RayCluster
    metadata:
      name: maxtext-tpu-cluster
    spec:
      headGroupSpec:
        rayStartParams: {}
        template:
          metadata:
            annotations:
              gke-gcsfuse/volumes: "true"
              gke-gcsfuse/cpu-limit: "0"
              gke-gcsfuse/memory-limit: "0"
              gke-gcsfuse/ephemeral-storage-limit: "0"
          spec:
            serviceAccountName: ${KSA_NAME}
            containers:
              - name: ray-head
                image: rayproject/ray:nightly-py312-tpu
                imagePullPolicy: Always
                ports:
                - containerPort: 6379
                  name: gcs-server
                - containerPort: 8265
                  name: dashboard
                - containerPort: 10001
                  name: client
                resources:
                  limits:
                    memory: "16Gi"
                  requests:
                    cpu: "8"
                    memory: "16Gi"
                volumeMounts:
                - name: gcs-fuse-csi-ephemeral
                  mountPath: /data
                - name: dshm
                  mountPath: /dev/shm
            volumes:
            - name: dshm
              emptyDir:
                medium: Memory
            - name: gcs-fuse-csi-ephemeral
              csi:
                driver: gcsfuse.csi.storage.gke.io
                volumeAttributes:
                  bucketName: ${GS_BUCKET}
                  mountOptions: "implicit-dirs,uid=1000,gid=1000,dir-mode=775,file-mode=664,file-cache:max-size-mb:-1"
            nodeSelector:
              iam.gke.io/gke-metadata-server-enabled: "true"
      workerGroupSpecs:
        - replicas: 2
          numOfHosts: 4
          groupName: tpu-group
          rayStartParams: 
            metrics-export-port: "8082"
          template:
            metadata:
              annotations:
                gke-gcsfuse/volumes: "true"
                gke-gcsfuse/cpu-limit: "0"
                gke-gcsfuse/memory-limit: "0"
                gke-gcsfuse/ephemeral-storage-limit: "0"
            spec:
              serviceAccountName: ${KSA_NAME}
              resourceClaims:
              - name: netdev
                resourceClaimTemplateName: two-netdev
              containers:
                - name: ray-worker
                  image: rayproject/ray:nightly-py312-tpu
                  imagePullPolicy: Always
                  resources:
                    claims:
                    - name: netdev
                    limits:
                      memory: 200G
                      google.com/tpu: "4"
                    requests:
                      cpu: "8"
                      memory: 200G
                      google.com/tpu: "4"
                  env:
                    - name: MEGASCALE_NUM_SLICES
                      value: "2"
                    - name: MEGASCALE_PORT
                      value: "9915"
                    - name: JAX_PLATFORMS
                      value: tpu,cpu
                    - name: ENABLE_PJRT_COMPATIBILITY
                      value: "true"
                    - name: LIBTPU_INIT_ARGS
                      value: "--xla_tpu_scoped_vmem_limit_kib=122880 --xla_tpu_use_minor_sharding_for_major_trivial_input=true --xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 --xla_tpu_assign_all_reduce_scatter_layout --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true --megascale_grpc_interface_prefixes=eth1,eth2,lo"
                  securityContext:
                    privileged: true
                  volumeMounts:
                  - name: gcs-fuse-csi-ephemeral
                    mountPath: /data
                  - name: dshm
                    mountPath: /dev/shm
              volumes:
              - name: dshm
                emptyDir:
                  medium: Memory
              - name: gcs-fuse-csi-ephemeral
                csi:
                  driver: gcsfuse.csi.storage.gke.io
                  volumeAttributes:
                    bucketName: ${GS_BUCKET}
                    mountOptions: "implicit-dirs,uid=1000,gid=1000,dir-mode=775,file-mode=664,file-cache:max-size-mb:-1"
              nodeSelector:
                iam.gke.io/gke-metadata-server-enabled: "true"
                cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                cloud.google.com/gke-tpu-topology: 4x4
    

    הגדרת ה-RayCluster שמופיעה למעלה יוצרת קבוצת עובדים של TPU עם שמונה עובדים (numOfHosts: 4) לכל עותק, עם שני עותקים. כל עובד מבקש ארבעה שבבי TPU‏ (google.com/tpu: "4"). כל אחד מהעובדים מתוזמן בצומת TPU Trillium‏ (tpu-v6e-slice), שהוא חלק מאותו פלח מרובה מארחים שנמצא באותו מיקום. ‫KubeRay משנה את הגודל של כל ארבעת העובדים בפרוסה באופן אטומי. משתני הסביבה הנדרשים של JAX, וגם Pod Affinities לתזמון, מופעלים על ידי GKE באמצעות webhook לשינוי.

  2. כדי ליצור את RayCluster, מפעילים את המניפסט:

    envsubst < ray-cluster.tpu-multi-slice.yaml | kubectl apply -f -
    
  3. מוודאים שהאשכול מוכן ופועל:

    kubectl get rayclusters maxtext-tpu-cluster
    

    הפלט אמור להיראות כך:

    NAME                  DESIRED WORKERS   AVAILABLE WORKERS   CPUS   MEMORY         GPUS   STATUS   AGE
    maxtext-tpu-cluster   8                 8                   72     1579277216Ki   0      ready    2m11s
    
  4. כדי לגשת ללוח הבקרה של Ray דרך שירות ה-Ray head, צריך ליצור סשן של העברת פורטים:

    kubectl port-forward svc/maxtext-tpu-cluster-head-svc 8265:8265 2>&1 >/dev/null &
    
  5. מוודאים שאפשר להגיע אל RayCluster מהסביבה המקומית:

    ray list nodes --address http://localhost:8265
    

    הפלט אמור להיראות כך:

    ray list nodes --address http://localhost:8265
    2026-04-21 10:20:20,080 - INFO - Note: NumExpr detected 64 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
    2026-04-21 10:20:20,080 - INFO - NumExpr defaulting to 8 threads.
    
    ======== List: 2026-04-21 10:20:20.945431 ========
    Stats:
    ------------------------------
    Total: 9
    
    Table:
    ------------------------------
        NODE_ID                                                   NODE_IP     IS_HEAD_NODE    STATE    STATE_MESSAGE    NODE_NAME    RESOURCES_TOTAL                   LABELS
    0  4f0e4d742de5375047c7688f4d2bc64a42d1e5c77c2d8344b3b375a1  10.68.9.5   False           ALIVE                     10.68.9.5    CPU: 8.0                          ray.io/accelerator-type: TPU-V6E
                                                                                                                                    TPU: 4.0                          ray.io/node-group: tpu-group
                                                                                                                                    accelerator_type:TPU-V6E: 1.0     ray.io/node-id: 4f0e4d742...
                                                                                                                                    memory: 186.265 GiB               ray.io/tpu-pod-type: v6e-16
                                                                                                                                    node:10.68.9.5: 1.0               ray.io/tpu-slice-name: tpu-group-0
                                                                                                                                    object_store_memory: 186.265 GiB  ray.io/tpu-topology: 4x4
                                                                                                                                    tpu-group-0: 1.0                  ray.io/tpu-worker-id: '1'
    ...
    6  ce7056807b95831ce107ba1951dac34b80635e6fdbb312e7f9649938  10.68.2.9   True            ALIVE                     10.68.2.9    CPU: 8.0                          ray.io/node-group: headgroup
                                                                                                                                    memory: 16.000 GiB                ray.io/node-id: ce7056807...
                                                                                                                                    node:10.68.2.9: 1.0
                                                                                                                                    node:__internal_head__: 1.0
                                                                                                                                    object_store_memory: 4.765 GiB
    ...
    
  6. מורידים את קובץ התצורה הבסיסי של MaxText. הקובץ הזה נדרש על ידי סקריפט האימון כדי להגדיר את היפר-הפרמטרים של ברירת המחדל של המודל:

    curl -O https://raw.githubusercontent.com/google/maxtext/maxtext-v0.2.1/src/maxtext/configs/base.yml
    
  7. שולחים את סקריפט JaxTrainer אל RayCluster ומוודאים ש-RayJob הושלם בהצלחה:

Llama 3 8B

ray job submit \
  --address http://localhost:8265 \
  --working-dir . \
  --runtime-env-json '{"excludes": ["ray-env", ".git"]}' \
  -- python maxtext_multi_slice_trainer.py \
      base.yml \
      base_output_directory=/data/ \
      dataset_type=synthetic \
      per_device_batch_size=4 \
      max_target_length=4096 \
      model_name=llama3-8b \
      steps=100 \
      ici_fsdp_parallelism=4 \
      ici_tensor_parallelism=4 \
      run_name=rayjob-multi-slice

Llama 3 70B

ray job submit \
  --address http://localhost:8265 \
  --working-dir . \
  --runtime-env-json '{"excludes": ["ray-env", ".git"]}' \
  -- python maxtext_multi_slice_trainer.py \
      base.yml \
      base_output_directory=/data/ \
      dataset_type=synthetic \
      per_device_batch_size=2 \
      max_target_length=4096 \
      model_name=llama3-70b \
      steps=100 \
      ici_tensor_parallelism=4 \
      ici_fsdp_parallelism=4 \
      dcn_fsdp_parallelism=2 \
      dcn_data_parallelism=1 \
      remat_policy=full \
      run_name=rayjob-multi-slice-70b-fsdp

הפקודה הקודמת שולחת את סקריפט Python, שקורא לקוד JaxTrainer Ray אל RayCluster. הפקודה ray job submit כוללת כמה ארגומנטים ספציפיים ל-MaxText שמועברים להגדרת המודל.

במסוף, הפלט של משימת Llama 3 70B אמור להיראות כך:

[process=5][thread=save_finalize][step=99] CheckpointManager Save Finalize is done on all hosts. [repeated 7x across cluster]
(RayTrainWorker pid=130520, ip=10.60.7.7) [process=5][thread=TrainingThread(train_fn_with_final_checkpoint_flush)][step=99][wait_until_finished] Done waiting for Save Finalize thread (save_finalize) running at step=99. [repeated 7x across cluster]
(RayTrainWorker pid=130520, ip=10.60.7.7) [process=5][thread=TrainingThread(train_fn_with_final_checkpoint_flush)][wait_until_finished] No Save Finalize thread to wait for. Returning. [repeated 6x across cluster]
(RayTrainWorker pid=130520, ip=10.60.7.7) completed step: 99, seconds: 0.693, TFLOP/s/device: 83.171, Tokens/s/device: 11819.175, total_weights: 262144, loss: 0.334 [repeated 6x across cluster]

------------------------------------------
Job 'raysubmit_XwUdZMrhsYRKvjqs' succeeded
------------------------------------------

הרצת אימון גמיש של כמה פרוסות במכונות וירטואליות של Spot

כשמשתמשים במאיצים מבוקשים כמו TPU, שימוש במכונות וירטואליות מסוג Spot עשוי להוזיל משמעותית את העלויות. עם זאת, יכול להיות שמכונות וירטואליות מסוג Spot יידחקו באופן בלתי צפוי.

‫Ray Train תומך באימון גמיש, שמאפשר למשימה שלכם לשנות באופן דינמי את מספר חלקי ה-TPU שמשתתפים באימון, בלי שהאימון ייכשל. אם מתבצעת קדימה של חלק, Ray משהה את לולאת האימון, מחכה שעובדי ה-worker הנותרים יתארגנו מחדש, משחזר מנקודת הבדיקה האחרונה של MaxText וממשיך את האימון בטביעת הרגל הקטנה יותר.

כדי להפעיל אימון גמיש, משנים את הפרמטר num_workers ב-ScalingConfig ממספר שלם סטטי לטופל שמייצג את (minimum_workers, maximum_workers). בנוסף, מוסיפים FailureConfig(max_failures=3) ל-RunConfig, שמורה ל-Ray Train לנסות שוב את לולאת האימון עד 3 פעמים במקום להפסיק את העבודה לגמרי כשמתבצעת הקצאה מראש של worker.

עדכון הסקריפט של Ray Train

  1. הסקריפט maxtext_elastic_trainer.py בספרייה הנוכחית מאפשר אימון גמיש. שימו לב שהערך שמוגדר הוא num_workers=(4,8), שמורה ל-Ray להמשיך אם יש לפחות פרוסת 16 שבבים אחת (ארבעה תהליכי עבודה), אבל להגדיל את מספר הפרוסות לשתיים (שמונה תהליכי עבודה) אם אפשר. היא כוללת FailureConfig כדי להפעיל אימון גמיש, להגדיר את מספר הניסיונות החוזרים ולעזור להבטיח שהעבודה תמשיך גם אם היא תיקטע:

    import os
    from absl import app
    import logging
    from typing import Sequence
    import ray
    from ray.train.v2.api.config import ScalingConfig, RunConfig, FailureConfig
    from ray.train.v2.jax import JaxTrainer
    
    def train_loop_per_worker(config):
        import maxtext
        from maxtext.trainers.pre_train.train import main as maxtext_main
    
        argv = config["argv"]
        maxtext_main(argv)
    
    def main(argv: Sequence[str]):
        # Convert the config file path to an absolute path
        argv = list(argv)
        if len(argv) > 1:
            argv[1] = os.path.abspath(argv[1])
    
        trainer = JaxTrainer(
            train_loop_per_worker=train_loop_per_worker,
            train_loop_config={"argv": argv},
            scaling_config=ScalingConfig(
                use_tpu=True,
                # This tells Ray to scale the number of workers between 4 and 8 (i.e. 1 to 2 TPU slices).
                num_workers=(4,8),
                topology="4x4",
                accelerator_type="TPU-V6E",
                resources_per_worker={"TPU": 4},
                placement_strategy="SPREAD",
            ),
            run_config=RunConfig(
                name="maxtext_jaxtrainer",
                # Define a FailureConfig to enable fault tolerance by automatically restarting failed workers.
                failure_config=FailureConfig(max_failures=3),
                worker_runtime_env={
                    "uv": {
                        # maxtext requires some additional deps
                        "packages": ["maxtext[tpu]==0.2.1"],
                        "uv_pip_install_options": ["--resolution=lowest"]
                    },
                },
            ),
        )
        result = trainer.fit()
        logging.info("Training complete!")
        ray.shutdown()
    
    if __name__ == "__main__":
        app.run(main)
    
  2. שולחים את העבודה באמצעות Ray Job CLI. חשוב לספק שם ייחודי לנקודת הבדיקה run_name כדי שלא יהיה קונפליקט עם הרצות קודמות.

    ray job submit \
      --address http://localhost:8265 \
      --working-dir . \
      --runtime-env-json '{"excludes": ["ray-env", ".git"]}' \
      -- python maxtext_elastic_trainer.py \
          base.yml \
          base_output_directory=/data/ \
          dataset_type=synthetic \
          per_device_batch_size=4 \
          max_target_length=4096 \
          model_name=llama3-8b \
          steps=100 \
          ici_fsdp_parallelism=4 \
          ici_tensor_parallelism=4 \
          run_name=rayjob-elastic-8b
    
  3. כדי לדמות סיום של צומת או קדימה במהלך אימון, מוחקים Pod.

    kubectl delete pod $(kubectl get pods -l ray.io/node-type=worker -o jsonpath='{.items[0].metadata.name}')
    

הטרמינל מתעד כשל של עובד, אבל בקר התזמור שומר על המשימה פעילה וממשיך אותה באופן אוטומטי מנקודת הבדיקה /data/rayjob-elastic-8b/checkpoints אחרי שהטופולוגיה המינימלית זמינה.

מכיוון ש-MaxText מחשב מחדש באופן דינמי את רשת המכשירים אחרי הפסקה, לא צריך לכתוב לוגיקה מותאמת אישית כדי לטפל בפיצול מחדש של נקודות ביקורת כשהטופולוגיה מצטמצמת. הכלי Orbax של JAX לבדיקת נקודות עצירה יחלק מחדש באופן אוטומטי את המשקלים השמורים לפריסה הפיזית החדשה לפני שימשיך את לולאת האימון. בפלט הבא אפשר לראות שבקר Ray Train מזהה משאבי TPU חדשים שזמינים באשכול, ומבצע פעולת שינוי גודל מ-slice אחד (ארבעה עובדים) לשני slices (שמונה עובדים) במהלך האימון.

...
(pid=, ip=10.68.9.5) W0421 04:19:07.570048   20579 grpc_transport.cc:1930] GetMultiSliceTopology returned with status: UNAVAILABLE: failed to connect to all addresses; last error: UNKNOWN: ipv4:10.68.8.5:9915: connect endpoint failed (Failed to connect to remote host: Connection refused)
...
(TrainController pid=23150) Detected changes in the cluster resources. Deciding to resize the worker group from 4 -> 8 workers.
(TrainController pid=23150) Using SlicePlacementGroup utility to reserve 2 slice(s) with topology '4x4'...
(TrainController pid=23150) Attempting to start training worker group of size 8 with the following resources: [{'TPU': 4, 'accelerator_type:TPU-V6E': 0.001}] * 8

הסרת המשאבים

כדי להימנע מחיובים בחשבון Google Cloud על המשאבים שבהם השתמשתם במדריך הזה, אתם יכולים למחוק את הפרויקט שמכיל את המשאבים או להשאיר את הפרויקט ולמחוק את המשאבים בנפרד.

  1. מחיקת RayCluster:

    kubectl delete raycluster maxtext-tpu-cluster
    
  2. מחיקת אשכול GKE:

    gcloud container clusters delete $CLUSTER_NAME --zone=$ZONE
    
  3. מוחקים את הקטגוריה של Cloud Storage:

    gsutil rm -r gs://${GS_BUCKET}
    

המאמרים הבאים