אימון מודל באמצעות TPU v5e

עם טביעת רגל קטנה יותר של 256 שבבים לכל Pod,‏ TPU v5e מותאם להיות מוצר בעל ערך גבוה לאימון, לכוונון עדין ולהצגה של טרנספורמציה, יצירת תמונות לפי טקסט ורשת עצבית קונבולוציונית (CNN). מידע נוסף על שימוש ב-Cloud TPU v5e להצגת מודלים זמין במאמר הסקת מסקנות באמצעות v5e.

מידע נוסף על חומרה והגדרות של Cloud TPU v5e TPU זמין במאמר בנושא TPU v5e.

קדימה, מתחילים

בקטעים הבאים מוסבר איך להתחיל להשתמש ב-TPU v5e.

בקשת מכסה

כדי להשתמש ב-TPU v5e לאימון, צריך מכסה. יש סוגים שונים של מכסות ל-TPU על פי דרישה, ל-TPU שמורים ולמכונות וירטואליות מסוג TPU Spot. אם אתם משתמשים ב-TPU v5e להסקת מסקנות, נדרשות מכסות נפרדות. מידע נוסף על מכסות זמין במאמר מכסות. כדי לבקש מכסת TPU v5e, פנו למחלקת המכירות של Cloud.

יצירת חשבון ופרויקט Google Cloud

כדי להשתמש ב-Cloud TPU, אתם צריכים חשבון ופרויקט ב- Google Cloud . מידע נוסף זמין במאמר הגדרת סביבת Cloud TPU.

יצירת Cloud TPU

השיטה המומלצת היא להקצות Cloud TPU v5es כמשאבים בתור באמצעות הפקודה queued-resource create. מידע נוסף זמין במאמר בנושא ניהול משאבים בתור.

אפשר גם להשתמש ב-Create Node API ‏ (gcloud compute tpus tpu-vm create) כדי להקצות Cloud TPU v5es. מידע נוסף זמין במאמר ניהול משאבי TPU.

מידע נוסף על ההגדרות הזמינות של v5e לאימון מופיע במאמר סוגי Cloud TPU v5e לאימון.

הגדרת ה-Framework

בקטע הזה מתואר תהליך ההגדרה הכללי לאימון מודלים בהתאמה אישית באמצעות JAX או PyTorch עם TPU v5e.

הוראות להגדרת הסקת מסקנות מופיעות במאמר מבוא להסקת מסקנות בגרסה v5e.

מגדירים כמה משתני סביבה:

export PROJECT_ID=your_project_ID
export ACCELERATOR_TYPE=v5litepod-16
export ZONE=us-west4-a
export TPU_NAME=your_tpu_name
export QUEUED_RESOURCE_ID=your_queued_resource_id

הגדרה של JAX

אם יש לכם צורות של פרוסות עם יותר מ-8 שבבים, יהיו לכם כמה מכונות וירטואליות בפרוסה אחת. במקרה כזה, צריך להשתמש בדגל --worker=all כדי להריץ את ההתקנה בכל מכונות ה-TPU הווירטואליות בשלב אחד, בלי להשתמש ב-SSH כדי להתחבר לכל אחת מהן בנפרד:

gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

תיאורים של דגלי פקודות

  • TPU_NAME: מזהה הטקסט שהמשתמש הקצה ל-TPU שנוצר כשהוקצתה בקשת המשאב שהוכנסה לתור.
  • PROJECT_ID: Google Cloud שם הפרויקט. משתמשים בפרויקט קיים או יוצרים פרויקט חדש בקטע הגדרת Google Cloud הפרויקט
  • ZONE: מידע על התחומים הנתמכים זמין במסמך TPU regions and zones (אזורים ותחומים של TPU).
  • worker: מכונת ה-TPU VM שיש לה גישה למעבדי ה-TPU הבסיסיים.

כדי לבדוק את מספר המכשירים, מריצים את הפקודה הבאה (הפלט שמוצג כאן נוצר באמצעות פרוסת v5litepod-16). הקוד הזה בודק שהכול מותקן בצורה נכונה. הוא עושה זאת על ידי בדיקה של JAX כדי לוודא שהוא מזהה את ליבות ה-Tensor של Cloud TPU ויכול להריץ פעולות בסיסיות:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='python3 -c "import jax; print(jax.device_count()); print(jax.local_device_count())"'

הפלט ייראה בערך כך:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16
4
16
4
16
4
16
4

jax.device_count() מציג את המספר הכולל של הצ'יפים בפלח הנתון. ‫jax.local_device_count() מציין את מספר השבבים שאפשר לגשת אליהם ממכונה וירטואלית אחת בפלח הזה.

# Check the number of chips in the given slice by summing the count of chips
# from all VMs through the
# jax.local_device_count() API call.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='python3 -c "import jax; xs=jax.numpy.ones(jax.local_device_count()); print(jax.pmap(lambda x: jax.lax.psum(x, \"i\"), axis_name=\"i\")(xs))"'

הפלט ייראה בערך כך:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]

כדי להתחיל להשתמש ב-JAX לאימון מודלים בגרסה v5e, כדאי לנסות את המדריכים ל-JAX שמופיעים במסמך הזה.

הגדרה של PyTorch

הערה: גרסה v5e תומכת רק ב-PJRT runtime ו-PyTorch 2.1+‎ ישתמש ב-PJRT כ-runtime ברירת המחדל לכל גרסאות ה-TPU.

בקטע הזה מוסבר איך להתחיל להשתמש ב-PJRT בגרסה v5e עם PyTorch/XLA באמצעות פקודות לכל העובדים.

התקנת יחסי תלות

gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
      sudo apt-get update -y
      sudo apt-get install libomp5 -y
      pip install mkl mkl-include
      pip install tf-nightly tb-nightly tbp-nightly
      pip install numpy
      sudo apt-get install libopenblas-dev -y
      pip install torch~=PYTORCH_VERSION torchvision torch_xla[tpu]~=PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'

מחליפים את PYTORCH_VERSION בגרסה של PyTorch שבה רוצים להשתמש. המאפיין PYTORCH_VERSION משמש לציון אותה גרסה של PyTorch/XLA. מומלצת גרסה 2.6.0.

מידע נוסף על גרסאות של PyTorch ו-PyTorch/XLA זמין במאמרים PyTorch - Get Started ו-PyTorch/XLA releases.

מידע נוסף על התקנת PyTorch/XLA זמין במאמר בנושא התקנת PyTorch/XLA.

אם מתקבלת שגיאה כשמנסים להתקין את קובצי ה-wheel של torch, torch_xla או torchvision, כמו pkg_resources.extern.packaging.requirements.InvalidRequirement: Expected end or semicolon (after name and no valid version specifier) torch==nightly+20230222, צריך לשנמך את הגרסה באמצעות הפקודה הבאה:

pip3 install setuptools==62.1.0

הרצת סקריפט באמצעות PJRT

unset LD_PRELOAD

הדוגמה הבאה מציגה שימוש בסקריפט Python לביצוע חישוב במכונה וירטואלית v5e:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
      export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.local/lib/
      export PJRT_DEVICE=TPU
      export PT_XLA_DEBUG=0
      export USE_TORCH=ON
      unset LD_PRELOAD
      export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
      python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"'

יווצר פלט שדומה לזה:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
xla:0
tensor([[ 1.8611, -0.3114, -2.4208],
[-1.0731, 0.3422, 3.1445],
[ 0.5743, 0.2379, 1.1105]], device='xla:0')
xla:0
tensor([[ 1.8611, -0.3114, -2.4208],
[-1.0731, 0.3422, 3.1445],
[ 0.5743, 0.2379, 1.1105]], device='xla:0')

כדי להתחיל באימון של גרסה 5e באמצעות PyTorch, אפשר לנסות את המדריכים של PyTorch שמופיעים במסמך הזה.

בסיום הסשן, מוחקים את ה-TPU ואת המשאב שמוכנס לתור. כדי למחוק משאב בתור, צריך למחוק את הפרוסה ואז את המשאב בתור בשני שלבים:

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

אפשר להשתמש בשני השלבים האלה גם כדי להסיר בקשות למשאבים שנמצאות בתור ומצבן הוא FAILED.

דוגמאות ל-JAX/FLAX

בקטעים הבאים מפורטות דוגמאות לאימון מודלים של JAX ו-FLAX ב-TPU v5e.

אימון ImageNet בגרסה v5e

במדריך הזה מוסבר איך לאמן את ImageNet בגרסה v5e באמצעות נתוני קלט פיקטיביים. אם רוצים להשתמש בנתונים אמיתיים, אפשר לעיין בקובץ ה-README ב-GitHub.

הגדרה

  1. יוצרים משתני סביבה:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west4-a
    export ACCELERATOR_TYPE=v5litepod-8
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    תיאורים של משתני סביבה

    • PROJECT_ID: מזהה הפרויקט ב- Google Cloud . משתמשים בפרויקט קיים או יוצרים פרויקט חדש.
    • TPU_NAME: השם של ה-TPU.
    • ZONE: האזור שבו תיצור את מכונת ה-TPU הווירטואלית. מידע נוסף על אזורים נתמכים זמין במאמר אזורים ותחומים של TPU.
    • ACCELERATOR_TYPE: סוג המאיץ מציין את הגרסה והגודל של Cloud TPU שרוצים ליצור. מידע נוסף על סוגי המאיצים הנתמכים בכל גרסת TPU זמין במאמר בנושא גרסאות TPU.
    • RUNTIME_VERSION: גרסת התוכנה של Cloud TPU.
    • SERVICE_ACCOUNT: כתובת האימייל של חשבון השירות. אפשר למצוא אותו בדף Service Accounts במסוף Google Cloud .

      לדוגמה: tpu-service-account@PROJECT_ID.iam.gserviceaccount.com

    • QUEUED_RESOURCE_ID: מזהה הטקסט שהמשתמש הקצה לבקשת המשאב שנוספה לתור.

  2. יצירת משאב TPU:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --service-account=${SERVICE_ACCOUNT}
    

    תוכלו להתחבר ב-SSH ל-TPU VM כשהמשאב בתור יהיה במצב ACTIVE:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    כשה-QueuedResource נמצא במצב ACTIVE, הפלט ייראה כך:

     state: ACTIVE
    
  3. מתקינים את הגרסה החדשה ביותר של JAX ו-jaxlib:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  4. משכפלים את מודל ImageNet ומתקינים את הדרישות המתאימות:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command="git clone https://github.com/coolkp/flax.git && cd flax && git checkout pmap-orbax-conversion && git pull"
    
  5. כדי ליצור נתונים מזויפים, המודל צריך מידע על המימדים של מערך הנתונים. אפשר לאסוף את הנתונים האלה מהמטא-נתונים של קבוצת הנתונים ImageNet:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command="cd flax/examples/imagenet && pip install -r requirements-cloud-tpu.txt"
    

אימון המודל

אחרי שמשלימים את כל השלבים הקודמים, אפשר לאמן את המודל.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command="cd flax/examples/imagenet && bash ../../tests/download_dataset_metadata.sh && JAX_PLATFORMS=tpu python imagenet_fake_data_benchmark.py"

מחיקת ה-TPU והמשאב בתור

בסיום הסשן, מוחקים את ה-TPU ואת המשאב שמוכנס לתור.

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

מודלים של Hugging Face FLAX

מודלים של Hugging Face שהוטמעו ב-FLAX פועלים מחוץ לקופסה ב-Cloud TPU v5e. בקטע הזה מוסבר איך להפעיל מודלים פופולריים.

אימון של ViT ב-Imagenette

במדריך הזה נלמד איך לאמן את מודל Vision Transformer‏ (ViT) מ-HuggingFace באמצעות מערך הנתונים Imagenette של Fast AI ב-Cloud TPU v5e.

מודל ViT היה הראשון שאומן בהצלחה על מקודד Transformer ב-ImageNet עם תוצאות מצוינות בהשוואה לרשתות קונבולוציה. מידע נוסף זמין במאמר סקירה כללית על ViT.

הגדרה

  1. יוצרים משתני סביבה:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west4-a
    export ACCELERATOR_TYPE=v5litepod-16
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    תיאורים של משתני סביבה

    • PROJECT_ID: מזהה הפרויקט ב- Google Cloud . משתמשים בפרויקט קיים או יוצרים פרויקט חדש.
    • TPU_NAME: השם של ה-TPU.
    • ZONE: האזור שבו תיצור את מכונת ה-TPU הווירטואלית. מידע נוסף על אזורים נתמכים זמין במאמר אזורים ותחומים של TPU.
    • ACCELERATOR_TYPE: סוג המאיץ מציין את הגרסה והגודל של Cloud TPU שרוצים ליצור. מידע נוסף על סוגי המאיצים הנתמכים בכל גרסת TPU זמין במאמר בנושא גרסאות TPU.
    • RUNTIME_VERSION: גרסת התוכנה של Cloud TPU.
    • SERVICE_ACCOUNT: כתובת האימייל של חשבון השירות. אפשר למצוא אותו בדף Service Accounts במסוף Google Cloud .

      לדוגמה: tpu-service-account@PROJECT_ID.iam.gserviceaccount.com

    • QUEUED_RESOURCE_ID: מזהה הטקסט שהמשתמש הקצה לבקשת המשאב שנוספה לתור.

  2. יצירת משאב TPU:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --service-account=${SERVICE_ACCOUNT}
    

    תוכלו להתחבר ב-SSH למכונת ה-TPU הווירטואלית ברגע שהמשאב בתור יהיה במצב ACTIVE:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    כשהמשאב בתור נמצא במצב ACTIVE, הפלט ייראה כך:

     state: ACTIVE
    
  3. מתקינים את JAX ואת הספרייה שלו:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  4. מורידים את המאגר של Hugging Face ומתקינים את הדרישות:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='git clone https://github.com/huggingface/transformers.git && cd transformers && pip install . && pip install -r examples/flax/_tests_requirements.txt && pip install --upgrade huggingface-hub urllib3 zipp && pip install tensorflow==2.19 && sed -i 's/torchvision==0.12.0+cpu/torchvision==0.22.1/' examples/flax/vision/requirements.txt && pip install -r examples/flax/vision/requirements.txt && pip install tf-keras'
    
  5. מורידים את מערך הנתונים Imagenette:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='cd transformers && wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz && tar -xvzf imagenette2.tgz'
    

אימון המודל

מאמנים את המודל עם מאגר שעבר מיפוי מראש בנפח 4GB.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='cd transformers && JAX_PLATFORMS=tpu python3 examples/flax/vision/run_image_classification.py --train_dir "imagenette2/train" --validation_dir "imagenette2/val" --output_dir "./vit-imagenette" --learning_rate 1e-3 --preprocessing_num_workers 32 --per_device_train_batch_size 8 --per_device_eval_batch_size 8 --model_name_or_path google/vit-base-patch16-224-in21k --num_train_epochs 3'

מחיקת ה-TPU והמשאב בתור

בסיום הסשן, מוחקים את ה-TPU ואת המשאב בתור.

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

תוצאות השוואה לשוק של ViT

סקריפט האימון הופעל ב-v5litepod-4, ב-v5litepod-16 וב-v5litepod-64. בטבלה הבאה מוצגים נתוני התפוקה עם סוגים שונים של מאיצים.

סוג המאיץ v5litepod-4 v5litepod-16 v5litepod-64
Epoch 3 3 3
גודל אצווה גלובלי 32 128 512
קצב העברת הנתונים (דוגמאות/שנייה) 263.40 429.34 470.71

אימון של מודל Diffusion על פוקימון

במדריך הזה תלמדו איך לאמן את מודל הדיפוזיה Stable Diffusion מ-HuggingFace באמצעות מערך הנתונים Pokémon ב-Cloud TPU v5e.

מודל Stable Diffusion הוא מודל סמוי של יצירת תמונות לפי טקסט שיוצר תמונות ריאליסטיות מכל קלט טקסט. מידע נוסף זמין במקורות המידע הבאים:

הגדרה

  1. מגדירים משתנה סביבה לשם של קטגוריית האחסון:

    export GCS_BUCKET_NAME=your_bucket_name
  2. מגדירים קטגוריית אחסון לפלט של המודל:

    gcloud storage buckets create gs://GCS_BUCKET_NAME \
        --project=your_project \
        --location=us-west1
  3. יוצרים משתני סביבה:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west1-c
    export ACCELERATOR_TYPE=v5litepod-16
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    תיאורים של משתני סביבה

    • PROJECT_ID: מזהה הפרויקט ב- Google Cloud . משתמשים בפרויקט קיים או יוצרים פרויקט חדש.
    • TPU_NAME: השם של ה-TPU.
    • ZONE: האזור שבו תיצור את מכונת ה-TPU הווירטואלית. מידע נוסף על אזורים נתמכים זמין במאמר אזורים ותחומים של TPU.
    • ACCELERATOR_TYPE: סוג המאיץ מציין את הגרסה והגודל של Cloud TPU שרוצים ליצור. מידע נוסף על סוגי המאיצים הנתמכים בכל גרסת TPU זמין במאמר בנושא גרסאות TPU.
    • RUNTIME_VERSION: גרסת התוכנה של Cloud TPU.
    • SERVICE_ACCOUNT: כתובת האימייל של חשבון השירות. אפשר למצוא אותו בדף Service Accounts במסוף Google Cloud .

      לדוגמה: tpu-service-account@PROJECT_ID.iam.gserviceaccount.com

    • QUEUED_RESOURCE_ID: מזהה הטקסט שהמשתמש הקצה לבקשת המשאב שנוספה לתור.

  4. יצירת משאב TPU:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --service-account=${SERVICE_ACCOUNT}
    

    תוכלו להתחבר ב-SSH ל-TPU VM כשהמשאב בתור יהיה במצב ACTIVE:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    כשהמשאב שנמצא בתור הוא במצב ACTIVE, הפלט ייראה כך:

     state: ACTIVE
    
  5. מתקינים את JAX ואת הספרייה שלה.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='pip install "jax[tpu]==0.4.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  6. מורידים את המאגר של HuggingFace ומתקינים את הדרישות.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
         --project=${PROJECT_ID} \
         --zone=${ZONE} \
         --worker=all \
         --command='git clone https://github.com/RissyRan/diffusers.git && cd diffusers && pip install . && pip install -U -r examples/text_to_image/requirements_flax.txt && pip install tensorflow==2.17.1 clu && pip install tensorboard==2.17.1'
    

אימון המודל

מאמנים את המודל עם מאגר שעבר מיפוי מראש בנפח 4GB.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} --zone=${ZONE} --project=${PROJECT_ID} --worker=all --command="
    git clone https://github.com/google/maxdiffusion
    cd maxdiffusion
    pip3 install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    pip3 install -r requirements.txt
    pip3 install .
    pip3 install gcsfs
    export LIBTPU_INIT_ARGS=''
    python -m src.maxdiffusion.train src/maxdiffusion/configs/base_2_base.yml run_name=my_run \
    jax_cache_dir=gs://${GCS_BUCKET_NAME} activations_dtype=bfloat16 weights_dtype=bfloat16 \
    per_device_batch_size=1 precision=DEFAULT dataset_save_location=gs://${GCS_BUCKET_NAME} \
    output_dir=gs://${GCS_BUCKET_NAME}/ attention=flash"

הסרת המשאבים

בסיום הסשן, מוחקים את ה-TPU, את המשאב בתור ואת קטגוריה של Cloud Storage.

  1. מחיקת TPU:

    gcloud compute tpus tpu-vm delete ${TPU_NAME} \
        --project=${PROJECT_ID} \
        --zone=${ZONE} \
        --quiet
    
  2. מחיקת המשאב שנוסף לתור:

    gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
        --project=${PROJECT_ID} \
        --zone=${ZONE} \
        --quiet
    
  3. מוחקים את הקטגוריה של Cloud Storage:

    gcloud storage rm -r gs://${GCS_BUCKET_NAME}
    

תוצאות השוואה לשוק של דיפוזיה

סקריפט האימון הופעל ב-v5litepod-4, ב-v5litepod-16 וב-v5litepod-64. בטבלה הבאה מוצגים נתוני התפוקה.

סוג המאיץ v5litepod-4 v5litepod-16 v5litepod-64
שלב האימון 1500 1500 1500
גודל אצווה גלובלי 32 64 128
קצב העברת הנתונים (דוגמאות/שנייה) 36.53 43.71 49.36

PyTorch/XLA

בקטעים הבאים מפורטות דוגמאות לאימון מודלים של PyTorch/XLA ב-TPU v5e.

אימון של ResNet באמצעות זמן הריצה של PJRT

מתבצעת מיגרציה של PyTorch/XLA מ-XRT ל-PjRt מ-PyTorch 2.0 ואילך. בהמשך מופיעות ההוראות המעודכנות להגדרת v5e עבור עומסי עבודה של אימון PyTorch/XLA.

הגדרה
  1. יוצרים משתני סביבה:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west4-a
    export ACCELERATOR_TYPE=v5litepod-16
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    תיאורים של משתני סביבה

    • PROJECT_ID: מזהה הפרויקט ב- Google Cloud . משתמשים בפרויקט קיים או יוצרים פרויקט חדש.
    • TPU_NAME: השם של ה-TPU.
    • ZONE: האזור שבו תיצור את מכונת ה-TPU הווירטואלית. מידע נוסף על אזורים נתמכים זמין במאמר אזורים ותחומים של TPU.
    • ACCELERATOR_TYPE: סוג המאיץ מציין את הגרסה והגודל של Cloud TPU שרוצים ליצור. מידע נוסף על סוגי המאיצים הנתמכים בכל גרסת TPU זמין במאמר בנושא גרסאות TPU.
    • RUNTIME_VERSION: גרסת התוכנה של Cloud TPU.
    • SERVICE_ACCOUNT: כתובת האימייל של חשבון השירות. אפשר למצוא אותו בדף Service Accounts במסוף Google Cloud .

      לדוגמה: tpu-service-account@PROJECT_ID.iam.gserviceaccount.com

    • QUEUED_RESOURCE_ID: מזהה הטקסט שהמשתמש הקצה לבקשת המשאב שנוספה לתור.

  2. יצירת משאב TPU:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --service-account=${SERVICE_ACCOUNT}
    

    תוכלו להתחבר ב-SSH ל-TPU VM אחרי שה-QueuedResource יהיה במצב ACTIVE:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    כשהמשאב שנמצא בתור הוא במצב ACTIVE, הפלט ייראה כך:

     state: ACTIVE
    
  3. התקנת יחסי תלות ספציפיים של Torch/XLA

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --project=${PROJECT_ID} \
      --zone=${ZONE} \
      --worker=all \
      --command='
         sudo apt-get update -y
         sudo apt-get install libomp5 -y
         pip3 install mkl mkl-include
         pip3 install tf-nightly tb-nightly tbp-nightly
         pip3 install numpy
         sudo apt-get install libopenblas-dev -y
         pip install torch==PYTORCH_VERSION torchvision torch_xla[tpu]==PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'

    מחליפים את PYTORCH_VERSION בגרסה של PyTorch שבה רוצים להשתמש. המאפיין PYTORCH_VERSION משמש לציון אותה גרסה של PyTorch/XLA. מומלצת גרסה 2.6.0.

    מידע נוסף על גרסאות של PyTorch ו-PyTorch/XLA זמין במאמרים PyTorch - Get Started ו-PyTorch/XLA releases.

    מידע נוסף על התקנת PyTorch/XLA זמין במאמר בנושא התקנת PyTorch/XLA.

אימון מודל ResNet
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
      date
      export PJRT_DEVICE=TPU
      export PT_XLA_DEBUG=0
      export USE_TORCH=ON
      export XLA_USE_BF16=1
      export LIBTPU_INIT_ARGS=--xla_jf_auto_cross_replica_sharding
      export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
      export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
      git clone https://github.com/pytorch/xla.git
      cd xla/
      git checkout release-r2.6
      python3 test/test_train_mp_imagenet.py --model=resnet50  --fake_data --num_epochs=1 --num_workers=16  --log_steps=300 --batch_size=64 --profile'

מחיקת ה-TPU והמשאב בתור

בסיום הסשן, מוחקים את ה-TPU ואת המשאב שמוכנס לתור.

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet
תוצאת ההשוואה לשוק

בטבלה הבאה מוצגים נתוני התפוקה של בדיקת הביצועים.

סוג המאיץ קצב העברת נתונים (דוגמאות לשנייה)
v5litepod-4 ‫4,240 ex/s
v5litepod-16 ‫10,810 ex/s
v5litepod-64 ‫46,154 ex/s

אימון של ViT ב-v5e

במדריך הזה נסביר איך להריץ VIT בגרסה v5e באמצעות מאגר HuggingFace ב-PyTorch/XLA בקבוצת הנתונים cifar10.

הגדרה

  1. יוצרים משתני סביבה:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west4-a
    export ACCELERATOR_TYPE=v5litepod-16
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    תיאורים של משתני סביבה

    • PROJECT_ID: מזהה הפרויקט ב- Google Cloud . משתמשים בפרויקט קיים או יוצרים פרויקט חדש.
    • TPU_NAME: השם של ה-TPU.
    • ZONE: האזור שבו תיצור את מכונת ה-TPU הווירטואלית. מידע נוסף על אזורים נתמכים זמין במאמר אזורים ותחומים של TPU.
    • ACCELERATOR_TYPE: סוג המאיץ מציין את הגרסה והגודל של Cloud TPU שרוצים ליצור. מידע נוסף על סוגי המאיצים הנתמכים בכל גרסת TPU זמין במאמר בנושא גרסאות TPU.
    • RUNTIME_VERSION: גרסת התוכנה של Cloud TPU.
    • SERVICE_ACCOUNT: כתובת האימייל של חשבון השירות. אפשר למצוא אותו בדף Service Accounts במסוף Google Cloud .

      לדוגמה: tpu-service-account@PROJECT_ID.iam.gserviceaccount.com

    • QUEUED_RESOURCE_ID: מזהה הטקסט שהמשתמש הקצה לבקשת המשאב שנוספה לתור.

  2. יצירת משאב TPU:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --service-account=${SERVICE_ACCOUNT}
    

    תוכלו להתחבר ב-SSH ל-TPU VM אחרי שה-QueuedResource יהיה במצב ACTIVE:

     gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    כשהמשאב בתור נמצא במצב ACTIVE, הפלט ייראה כך:

     state: ACTIVE
    
  3. התקנת יחסי תלות של PyTorch/XLA

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='
      sudo apt-get update -y
      sudo apt-get install libomp5 -y
      pip3 install mkl mkl-include
      pip3 install tf-nightly tb-nightly tbp-nightly
      pip3 install numpy
      sudo apt-get install libopenblas-dev -y
      pip install torch==PYTORCH_VERSION torchvision torch_xla[tpu]==PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
      pip install jax==0.4.38 jaxlib==0.4.38 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/

    מחליפים את PYTORCH_VERSION בגרסה של PyTorch שבה רוצים להשתמש. המאפיין PYTORCH_VERSION משמש לציון אותה גרסה של PyTorch/XLA. מומלצת גרסה 2.6.0.

    מידע נוסף על גרסאות של PyTorch ו-PyTorch/XLA זמין במאמרים PyTorch - Get Started ו-PyTorch/XLA releases.

    מידע נוסף על התקנת PyTorch/XLA זמין במאמר בנושא התקנת PyTorch/XLA.

  4. מורידים את המאגר של HuggingFace ומתקינים את הדרישות.

       gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command="
          git clone https://github.com/suexu1025/transformers.git vittransformers; \
          cd vittransformers; \
          pip3 install .; \
          pip3 install datasets; \
          wget https://github.com/pytorch/xla/blob/master/scripts/capture_profile.py"
    

אימון המודל

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
      export PJRT_DEVICE=TPU
      export PT_XLA_DEBUG=0
      export USE_TORCH=ON
      export TF_CPP_MIN_LOG_LEVEL=0
      export XLA_USE_BF16=1
      export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
      export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
      cd vittransformers
      python3 -u examples/pytorch/xla_spawn.py --num_cores 4 examples/pytorch/image-pretraining/run_mae.py --dataset_name=cifar10 \
      --remove_unused_columns=False \
      --label_names=pixel_values \
      --mask_ratio=0.75 \
      --norm_pix_loss=True \
      --do_train=true \
      --do_eval=true \
      --base_learning_rate=1.5e-4 \
      --lr_scheduler_type=cosine \
      --weight_decay=0.05 \
      --num_train_epochs=3 \
      --warmup_ratio=0.05 \
      --per_device_train_batch_size=8 \
      --per_device_eval_batch_size=8 \
      --logging_strategy=steps \
      --logging_steps=30 \
      --evaluation_strategy=epoch \
      --save_strategy=epoch \
      --load_best_model_at_end=True \
      --save_total_limit=3 \
      --seed=1337 \
      --output_dir=MAE \
      --overwrite_output_dir=true \
      --logging_dir=./tensorboard-metrics \
      --tpu_metrics_debug=true'

מחיקת ה-TPU והמשאב בתור

בסיום הסשן, מוחקים את ה-TPU ואת המשאב שמוכנס לתור.

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

תוצאת ההשוואה לשוק

בטבלה הבאה מוצגים נתוני התפוקה של מדד ההשוואה לסוגים שונים של מאיצים.

v5litepod-4 v5litepod-16 v5litepod-64
Epoch 3 3 3
גודל אצווה גלובלי 32 128 512
קצב העברת הנתונים (דוגמאות/שנייה) 201 657 2,844