אימון מודל באמצעות TPU v5e
עם טביעת רגל קטנה יותר של 256 שבבים לכל Pod, TPU v5e מותאם להיות מוצר בעל ערך גבוה לאימון, לכוונון עדין ולהצגה של טרנספורמציה, יצירת תמונות לפי טקסט ורשת עצבית קונבולוציונית (CNN). מידע נוסף על שימוש ב-Cloud TPU v5e להצגת מודלים זמין במאמר הסקת מסקנות באמצעות v5e.
מידע נוסף על חומרה והגדרות של Cloud TPU v5e TPU זמין במאמר בנושא TPU v5e.
קדימה, מתחילים
בקטעים הבאים מוסבר איך להתחיל להשתמש ב-TPU v5e.
בקשת מכסה
כדי להשתמש ב-TPU v5e לאימון, צריך מכסה. יש סוגים שונים של מכסות ל-TPU על פי דרישה, ל-TPU שמורים ולמכונות וירטואליות מסוג TPU Spot. אם אתם משתמשים ב-TPU v5e להסקת מסקנות, נדרשות מכסות נפרדות. מידע נוסף על מכסות זמין במאמר מכסות. כדי לבקש מכסת TPU v5e, פנו למחלקת המכירות של Cloud.
יצירת חשבון ופרויקט Google Cloud
כדי להשתמש ב-Cloud TPU, אתם צריכים חשבון ופרויקט ב- Google Cloud . מידע נוסף זמין במאמר הגדרת סביבת Cloud TPU.
יצירת Cloud TPU
השיטה המומלצת היא להקצות Cloud TPU v5es כמשאבים בתור באמצעות הפקודה queued-resource create. מידע נוסף זמין במאמר בנושא ניהול משאבים בתור.
אפשר גם להשתמש ב-Create Node API (gcloud compute tpus tpu-vm create) כדי להקצות Cloud TPU v5es. מידע נוסף זמין במאמר ניהול משאבי TPU.
מידע נוסף על ההגדרות הזמינות של v5e לאימון מופיע במאמר סוגי Cloud TPU v5e לאימון.
הגדרת ה-Framework
בקטע הזה מתואר תהליך ההגדרה הכללי לאימון מודלים בהתאמה אישית באמצעות JAX או PyTorch עם TPU v5e.
הוראות להגדרת הסקת מסקנות מופיעות במאמר מבוא להסקת מסקנות בגרסה v5e.
מגדירים כמה משתני סביבה:
export PROJECT_ID=your_project_ID export ACCELERATOR_TYPE=v5litepod-16 export ZONE=us-west4-a export TPU_NAME=your_tpu_name export QUEUED_RESOURCE_ID=your_queued_resource_id
הגדרה של JAX
אם יש לכם צורות של פרוסות עם יותר מ-8 שבבים, יהיו לכם כמה מכונות וירטואליות בפרוסה אחת. במקרה כזה, צריך להשתמש בדגל --worker=all כדי להריץ את ההתקנה בכל מכונות ה-TPU הווירטואליות בשלב אחד, בלי להשתמש ב-SSH כדי להתחבר לכל אחת מהן בנפרד:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
תיאורים של דגלי פקודות
-
TPU_NAME: מזהה הטקסט שהמשתמש הקצה ל-TPU שנוצר כשהוקצתה בקשת המשאב שהוכנסה לתור. -
PROJECT_ID: Google Cloud שם הפרויקט. משתמשים בפרויקט קיים או יוצרים פרויקט חדש בקטע הגדרת Google Cloud הפרויקט -
ZONE: מידע על התחומים הנתמכים זמין במסמך TPU regions and zones (אזורים ותחומים של TPU). -
worker: מכונת ה-TPU VM שיש לה גישה למעבדי ה-TPU הבסיסיים.
כדי לבדוק את מספר המכשירים, מריצים את הפקודה הבאה (הפלט שמוצג כאן נוצר באמצעות פרוסת v5litepod-16). הקוד הזה בודק שהכול מותקן בצורה נכונה. הוא עושה זאת על ידי בדיקה של JAX כדי לוודא שהוא מזהה את ליבות ה-Tensor של Cloud TPU ויכול להריץ פעולות בסיסיות:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='python3 -c "import jax; print(jax.device_count()); print(jax.local_device_count())"'
הפלט ייראה בערך כך:
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16
4
16
4
16
4
16
4
jax.device_count() מציג את המספר הכולל של הצ'יפים בפלח הנתון.
jax.local_device_count() מציין את מספר השבבים שאפשר לגשת אליהם ממכונה וירטואלית אחת בפלח הזה.
# Check the number of chips in the given slice by summing the count of chips
# from all VMs through the
# jax.local_device_count() API call.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='python3 -c "import jax; xs=jax.numpy.ones(jax.local_device_count()); print(jax.pmap(lambda x: jax.lax.psum(x, \"i\"), axis_name=\"i\")(xs))"'
הפלט ייראה בערך כך:
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]
כדי להתחיל להשתמש ב-JAX לאימון מודלים בגרסה v5e, כדאי לנסות את המדריכים ל-JAX שמופיעים במסמך הזה.
הגדרה של PyTorch
הערה: גרסה v5e תומכת רק ב-PJRT runtime ו-PyTorch 2.1+ ישתמש ב-PJRT כ-runtime ברירת המחדל לכל גרסאות ה-TPU.
בקטע הזה מוסבר איך להתחיל להשתמש ב-PJRT בגרסה v5e עם PyTorch/XLA באמצעות פקודות לכל העובדים.
התקנת יחסי תלות
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' sudo apt-get update -y sudo apt-get install libomp5 -y pip install mkl mkl-include pip install tf-nightly tb-nightly tbp-nightly pip install numpy sudo apt-get install libopenblas-dev -y pip install torch~=PYTORCH_VERSION torchvision torch_xla[tpu]~=PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'
מחליפים את PYTORCH_VERSION בגרסה של PyTorch שבה רוצים להשתמש.
המאפיין PYTORCH_VERSION משמש לציון אותה גרסה של PyTorch/XLA. מומלצת גרסה 2.6.0.
מידע נוסף על גרסאות של PyTorch ו-PyTorch/XLA זמין במאמרים PyTorch - Get Started ו-PyTorch/XLA releases.
מידע נוסף על התקנת PyTorch/XLA זמין במאמר בנושא התקנת PyTorch/XLA.
אם מתקבלת שגיאה כשמנסים להתקין את קובצי ה-wheel של torch, torch_xla או torchvision, כמו pkg_resources.extern.packaging.requirements.InvalidRequirement: Expected end
or semicolon (after name and no valid version specifier) torch==nightly+20230222, צריך לשנמך את הגרסה באמצעות הפקודה הבאה:
pip3 install setuptools==62.1.0
הרצת סקריפט באמצעות PJRT
unset LD_PRELOAD
הדוגמה הבאה מציגה שימוש בסקריפט Python לביצוע חישוב במכונה וירטואלית v5e:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.local/lib/
export PJRT_DEVICE=TPU
export PT_XLA_DEBUG=0
export USE_TORCH=ON
unset LD_PRELOAD
export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"'
יווצר פלט שדומה לזה:
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
xla:0
tensor([[ 1.8611, -0.3114, -2.4208],
[-1.0731, 0.3422, 3.1445],
[ 0.5743, 0.2379, 1.1105]], device='xla:0')
xla:0
tensor([[ 1.8611, -0.3114, -2.4208],
[-1.0731, 0.3422, 3.1445],
[ 0.5743, 0.2379, 1.1105]], device='xla:0')
כדי להתחיל באימון של גרסה 5e באמצעות PyTorch, אפשר לנסות את המדריכים של PyTorch שמופיעים במסמך הזה.
בסיום הסשן, מוחקים את ה-TPU ואת המשאב שמוכנס לתור. כדי למחוק משאב בתור, צריך למחוק את הפרוסה ואז את המשאב בתור בשני שלבים:
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
אפשר להשתמש בשני השלבים האלה גם כדי להסיר בקשות למשאבים שנמצאות בתור ומצבן הוא FAILED.
דוגמאות ל-JAX/FLAX
בקטעים הבאים מפורטות דוגמאות לאימון מודלים של JAX ו-FLAX ב-TPU v5e.
אימון ImageNet בגרסה v5e
במדריך הזה מוסבר איך לאמן את ImageNet בגרסה v5e באמצעות נתוני קלט פיקטיביים. אם רוצים להשתמש בנתונים אמיתיים, אפשר לעיין בקובץ ה-README ב-GitHub.
הגדרה
יוצרים משתני סביבה:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west4-a export ACCELERATOR_TYPE=v5litepod-8 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
תיאורים של משתני סביבה
PROJECT_ID: מזהה הפרויקט ב- Google Cloud . משתמשים בפרויקט קיים או יוצרים פרויקט חדש.-
TPU_NAME: השם של ה-TPU. -
ZONE: האזור שבו תיצור את מכונת ה-TPU הווירטואלית. מידע נוסף על אזורים נתמכים זמין במאמר אזורים ותחומים של TPU. -
ACCELERATOR_TYPE: סוג המאיץ מציין את הגרסה והגודל של Cloud TPU שרוצים ליצור. מידע נוסף על סוגי המאיצים הנתמכים בכל גרסת TPU זמין במאמר בנושא גרסאות TPU. -
RUNTIME_VERSION: גרסת התוכנה של Cloud TPU. SERVICE_ACCOUNT: כתובת האימייל של חשבון השירות. אפשר למצוא אותו בדף Service Accounts במסוף Google Cloud .לדוגמה:
tpu-service-account@PROJECT_ID.iam.gserviceaccount.com-
QUEUED_RESOURCE_ID: מזהה הטקסט שהמשתמש הקצה לבקשת המשאב שנוספה לתור.
-
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}תוכלו להתחבר ב-SSH ל-TPU VM כשהמשאב בתור יהיה במצב
ACTIVE:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}כשה-QueuedResource נמצא במצב
ACTIVE, הפלט ייראה כך:state: ACTIVE מתקינים את הגרסה החדשה ביותר של JAX ו-jaxlib:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'משכפלים את מודל ImageNet ומתקינים את הדרישות המתאימות:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command="git clone https://github.com/coolkp/flax.git && cd flax && git checkout pmap-orbax-conversion && git pull"כדי ליצור נתונים מזויפים, המודל צריך מידע על המימדים של מערך הנתונים. אפשר לאסוף את הנתונים האלה מהמטא-נתונים של קבוצת הנתונים ImageNet:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command="cd flax/examples/imagenet && pip install -r requirements-cloud-tpu.txt"
אימון המודל
אחרי שמשלימים את כל השלבים הקודמים, אפשר לאמן את המודל.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command="cd flax/examples/imagenet && bash ../../tests/download_dataset_metadata.sh && JAX_PLATFORMS=tpu python imagenet_fake_data_benchmark.py"
מחיקת ה-TPU והמשאב בתור
בסיום הסשן, מוחקים את ה-TPU ואת המשאב שמוכנס לתור.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
מודלים של Hugging Face FLAX
מודלים של Hugging Face שהוטמעו ב-FLAX פועלים מחוץ לקופסה ב-Cloud TPU v5e. בקטע הזה מוסבר איך להפעיל מודלים פופולריים.
אימון של ViT ב-Imagenette
במדריך הזה נלמד איך לאמן את מודל Vision Transformer (ViT) מ-HuggingFace באמצעות מערך הנתונים Imagenette של Fast AI ב-Cloud TPU v5e.
מודל ViT היה הראשון שאומן בהצלחה על מקודד Transformer ב-ImageNet עם תוצאות מצוינות בהשוואה לרשתות קונבולוציה. מידע נוסף זמין במאמר סקירה כללית על ViT.
הגדרה
יוצרים משתני סביבה:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west4-a export ACCELERATOR_TYPE=v5litepod-16 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
תיאורים של משתני סביבה
PROJECT_ID: מזהה הפרויקט ב- Google Cloud . משתמשים בפרויקט קיים או יוצרים פרויקט חדש.-
TPU_NAME: השם של ה-TPU. -
ZONE: האזור שבו תיצור את מכונת ה-TPU הווירטואלית. מידע נוסף על אזורים נתמכים זמין במאמר אזורים ותחומים של TPU. -
ACCELERATOR_TYPE: סוג המאיץ מציין את הגרסה והגודל של Cloud TPU שרוצים ליצור. מידע נוסף על סוגי המאיצים הנתמכים בכל גרסת TPU זמין במאמר בנושא גרסאות TPU. -
RUNTIME_VERSION: גרסת התוכנה של Cloud TPU. SERVICE_ACCOUNT: כתובת האימייל של חשבון השירות. אפשר למצוא אותו בדף Service Accounts במסוף Google Cloud .לדוגמה:
tpu-service-account@PROJECT_ID.iam.gserviceaccount.com-
QUEUED_RESOURCE_ID: מזהה הטקסט שהמשתמש הקצה לבקשת המשאב שנוספה לתור.
-
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}תוכלו להתחבר ב-SSH למכונת ה-TPU הווירטואלית ברגע שהמשאב בתור יהיה במצב
ACTIVE:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}כשהמשאב בתור נמצא במצב
ACTIVE, הפלט ייראה כך:state: ACTIVE מתקינים את JAX ואת הספרייה שלו:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'מורידים את המאגר של Hugging Face ומתקינים את הדרישות:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='git clone https://github.com/huggingface/transformers.git && cd transformers && pip install . && pip install -r examples/flax/_tests_requirements.txt && pip install --upgrade huggingface-hub urllib3 zipp && pip install tensorflow==2.19 && sed -i 's/torchvision==0.12.0+cpu/torchvision==0.22.1/' examples/flax/vision/requirements.txt && pip install -r examples/flax/vision/requirements.txt && pip install tf-keras'מורידים את מערך הנתונים Imagenette:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='cd transformers && wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz && tar -xvzf imagenette2.tgz'
אימון המודל
מאמנים את המודל עם מאגר שעבר מיפוי מראש בנפח 4GB.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='cd transformers && JAX_PLATFORMS=tpu python3 examples/flax/vision/run_image_classification.py --train_dir "imagenette2/train" --validation_dir "imagenette2/val" --output_dir "./vit-imagenette" --learning_rate 1e-3 --preprocessing_num_workers 32 --per_device_train_batch_size 8 --per_device_eval_batch_size 8 --model_name_or_path google/vit-base-patch16-224-in21k --num_train_epochs 3'
מחיקת ה-TPU והמשאב בתור
בסיום הסשן, מוחקים את ה-TPU ואת המשאב בתור.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
תוצאות השוואה לשוק של ViT
סקריפט האימון הופעל ב-v5litepod-4, ב-v5litepod-16 וב-v5litepod-64. בטבלה הבאה מוצגים נתוני התפוקה עם סוגים שונים של מאיצים.
| סוג המאיץ | v5litepod-4 | v5litepod-16 | v5litepod-64 |
| Epoch | 3 | 3 | 3 |
| גודל אצווה גלובלי | 32 | 128 | 512 |
| קצב העברת הנתונים (דוגמאות/שנייה) | 263.40 | 429.34 | 470.71 |
אימון של מודל Diffusion על פוקימון
במדריך הזה תלמדו איך לאמן את מודל הדיפוזיה Stable Diffusion מ-HuggingFace באמצעות מערך הנתונים Pokémon ב-Cloud TPU v5e.
מודל Stable Diffusion הוא מודל סמוי של יצירת תמונות לפי טקסט שיוצר תמונות ריאליסטיות מכל קלט טקסט. מידע נוסף זמין במקורות המידע הבאים:
הגדרה
מגדירים משתנה סביבה לשם של קטגוריית האחסון:
export GCS_BUCKET_NAME=your_bucket_name
מגדירים קטגוריית אחסון לפלט של המודל:
gcloud storage buckets create gs://GCS_BUCKET_NAME \ --project=your_project \ --location=us-west1
יוצרים משתני סביבה:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west1-c export ACCELERATOR_TYPE=v5litepod-16 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
תיאורים של משתני סביבה
PROJECT_ID: מזהה הפרויקט ב- Google Cloud . משתמשים בפרויקט קיים או יוצרים פרויקט חדש.-
TPU_NAME: השם של ה-TPU. -
ZONE: האזור שבו תיצור את מכונת ה-TPU הווירטואלית. מידע נוסף על אזורים נתמכים זמין במאמר אזורים ותחומים של TPU. -
ACCELERATOR_TYPE: סוג המאיץ מציין את הגרסה והגודל של Cloud TPU שרוצים ליצור. מידע נוסף על סוגי המאיצים הנתמכים בכל גרסת TPU זמין במאמר בנושא גרסאות TPU. -
RUNTIME_VERSION: גרסת התוכנה של Cloud TPU. SERVICE_ACCOUNT: כתובת האימייל של חשבון השירות. אפשר למצוא אותו בדף Service Accounts במסוף Google Cloud .לדוגמה:
tpu-service-account@PROJECT_ID.iam.gserviceaccount.com-
QUEUED_RESOURCE_ID: מזהה הטקסט שהמשתמש הקצה לבקשת המשאב שנוספה לתור.
-
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}תוכלו להתחבר ב-SSH ל-TPU VM כשהמשאב בתור יהיה במצב
ACTIVE:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}כשהמשאב שנמצא בתור הוא במצב
ACTIVE, הפלט ייראה כך:state: ACTIVE מתקינים את JAX ואת הספרייה שלה.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='pip install "jax[tpu]==0.4.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'מורידים את המאגר של HuggingFace ומתקינים את הדרישות.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='git clone https://github.com/RissyRan/diffusers.git && cd diffusers && pip install . && pip install -U -r examples/text_to_image/requirements_flax.txt && pip install tensorflow==2.17.1 clu && pip install tensorboard==2.17.1'
אימון המודל
מאמנים את המודל עם מאגר שעבר מיפוי מראש בנפח 4GB.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} --zone=${ZONE} --project=${PROJECT_ID} --worker=all --command="
git clone https://github.com/google/maxdiffusion
cd maxdiffusion
pip3 install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip3 install -r requirements.txt
pip3 install .
pip3 install gcsfs
export LIBTPU_INIT_ARGS=''
python -m src.maxdiffusion.train src/maxdiffusion/configs/base_2_base.yml run_name=my_run \
jax_cache_dir=gs://${GCS_BUCKET_NAME} activations_dtype=bfloat16 weights_dtype=bfloat16 \
per_device_batch_size=1 precision=DEFAULT dataset_save_location=gs://${GCS_BUCKET_NAME} \
output_dir=gs://${GCS_BUCKET_NAME}/ attention=flash"
הסרת המשאבים
בסיום הסשן, מוחקים את ה-TPU, את המשאב בתור ואת קטגוריה של Cloud Storage.
מחיקת TPU:
gcloud compute tpus tpu-vm delete ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --quietמחיקת המשאב שנוסף לתור:
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --quietמוחקים את הקטגוריה של Cloud Storage:
gcloud storage rm -r gs://${GCS_BUCKET_NAME}
תוצאות השוואה לשוק של דיפוזיה
סקריפט האימון הופעל ב-v5litepod-4, ב-v5litepod-16 וב-v5litepod-64. בטבלה הבאה מוצגים נתוני התפוקה.
| סוג המאיץ | v5litepod-4 | v5litepod-16 | v5litepod-64 |
| שלב האימון | 1500 | 1500 | 1500 |
| גודל אצווה גלובלי | 32 | 64 | 128 |
| קצב העברת הנתונים (דוגמאות/שנייה) | 36.53 | 43.71 | 49.36 |
PyTorch/XLA
בקטעים הבאים מפורטות דוגמאות לאימון מודלים של PyTorch/XLA ב-TPU v5e.
אימון של ResNet באמצעות זמן הריצה של PJRT
מתבצעת מיגרציה של PyTorch/XLA מ-XRT ל-PjRt מ-PyTorch 2.0 ואילך. בהמשך מופיעות ההוראות המעודכנות להגדרת v5e עבור עומסי עבודה של אימון PyTorch/XLA.
הגדרה
יוצרים משתני סביבה:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west4-a export ACCELERATOR_TYPE=v5litepod-16 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
תיאורים של משתני סביבה
PROJECT_ID: מזהה הפרויקט ב- Google Cloud . משתמשים בפרויקט קיים או יוצרים פרויקט חדש.-
TPU_NAME: השם של ה-TPU. -
ZONE: האזור שבו תיצור את מכונת ה-TPU הווירטואלית. מידע נוסף על אזורים נתמכים זמין במאמר אזורים ותחומים של TPU. -
ACCELERATOR_TYPE: סוג המאיץ מציין את הגרסה והגודל של Cloud TPU שרוצים ליצור. מידע נוסף על סוגי המאיצים הנתמכים בכל גרסת TPU זמין במאמר בנושא גרסאות TPU. -
RUNTIME_VERSION: גרסת התוכנה של Cloud TPU. SERVICE_ACCOUNT: כתובת האימייל של חשבון השירות. אפשר למצוא אותו בדף Service Accounts במסוף Google Cloud .לדוגמה:
tpu-service-account@PROJECT_ID.iam.gserviceaccount.com-
QUEUED_RESOURCE_ID: מזהה הטקסט שהמשתמש הקצה לבקשת המשאב שנוספה לתור.
-
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}תוכלו להתחבר ב-SSH ל-TPU VM אחרי שה-QueuedResource יהיה במצב
ACTIVE:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}כשהמשאב שנמצא בתור הוא במצב
ACTIVE, הפלט ייראה כך:state: ACTIVE התקנת יחסי תלות ספציפיים של Torch/XLA
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' sudo apt-get update -y sudo apt-get install libomp5 -y pip3 install mkl mkl-include pip3 install tf-nightly tb-nightly tbp-nightly pip3 install numpy sudo apt-get install libopenblas-dev -y pip install torch==PYTORCH_VERSION torchvision torch_xla[tpu]==PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'
מחליפים את
PYTORCH_VERSIONבגרסה של PyTorch שבה רוצים להשתמש. המאפייןPYTORCH_VERSIONמשמש לציון אותה גרסה של PyTorch/XLA. מומלצת גרסה 2.6.0.מידע נוסף על גרסאות של PyTorch ו-PyTorch/XLA זמין במאמרים PyTorch - Get Started ו-PyTorch/XLA releases.
מידע נוסף על התקנת PyTorch/XLA זמין במאמר בנושא התקנת PyTorch/XLA.
אימון מודל ResNet
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
date
export PJRT_DEVICE=TPU
export PT_XLA_DEBUG=0
export USE_TORCH=ON
export XLA_USE_BF16=1
export LIBTPU_INIT_ARGS=--xla_jf_auto_cross_replica_sharding
export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
git clone https://github.com/pytorch/xla.git
cd xla/
git checkout release-r2.6
python3 test/test_train_mp_imagenet.py --model=resnet50 --fake_data --num_epochs=1 --num_workers=16 --log_steps=300 --batch_size=64 --profile'
מחיקת ה-TPU והמשאב בתור
בסיום הסשן, מוחקים את ה-TPU ואת המשאב שמוכנס לתור.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
תוצאת ההשוואה לשוק
בטבלה הבאה מוצגים נתוני התפוקה של בדיקת הביצועים.
| סוג המאיץ | קצב העברת נתונים (דוגמאות לשנייה) |
| v5litepod-4 | 4,240 ex/s |
| v5litepod-16 | 10,810 ex/s |
| v5litepod-64 | 46,154 ex/s |
אימון של ViT ב-v5e
במדריך הזה נסביר איך להריץ VIT בגרסה v5e באמצעות מאגר HuggingFace ב-PyTorch/XLA בקבוצת הנתונים cifar10.
הגדרה
יוצרים משתני סביבה:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west4-a export ACCELERATOR_TYPE=v5litepod-16 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
תיאורים של משתני סביבה
PROJECT_ID: מזהה הפרויקט ב- Google Cloud . משתמשים בפרויקט קיים או יוצרים פרויקט חדש.-
TPU_NAME: השם של ה-TPU. -
ZONE: האזור שבו תיצור את מכונת ה-TPU הווירטואלית. מידע נוסף על אזורים נתמכים זמין במאמר אזורים ותחומים של TPU. -
ACCELERATOR_TYPE: סוג המאיץ מציין את הגרסה והגודל של Cloud TPU שרוצים ליצור. מידע נוסף על סוגי המאיצים הנתמכים בכל גרסת TPU זמין במאמר בנושא גרסאות TPU. -
RUNTIME_VERSION: גרסת התוכנה של Cloud TPU. SERVICE_ACCOUNT: כתובת האימייל של חשבון השירות. אפשר למצוא אותו בדף Service Accounts במסוף Google Cloud .לדוגמה:
tpu-service-account@PROJECT_ID.iam.gserviceaccount.com-
QUEUED_RESOURCE_ID: מזהה הטקסט שהמשתמש הקצה לבקשת המשאב שנוספה לתור.
-
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}תוכלו להתחבר ב-SSH ל-TPU VM אחרי שה-QueuedResource יהיה במצב
ACTIVE:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}כשהמשאב בתור נמצא במצב
ACTIVE, הפלט ייראה כך:state: ACTIVE התקנת יחסי תלות של PyTorch/XLA
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' sudo apt-get update -y sudo apt-get install libomp5 -y pip3 install mkl mkl-include pip3 install tf-nightly tb-nightly tbp-nightly pip3 install numpy sudo apt-get install libopenblas-dev -y pip install torch==PYTORCH_VERSION torchvision torch_xla[tpu]==PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html pip install jax==0.4.38 jaxlib==0.4.38 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
מחליפים את
PYTORCH_VERSIONבגרסה של PyTorch שבה רוצים להשתמש. המאפייןPYTORCH_VERSIONמשמש לציון אותה גרסה של PyTorch/XLA. מומלצת גרסה 2.6.0.מידע נוסף על גרסאות של PyTorch ו-PyTorch/XLA זמין במאמרים PyTorch - Get Started ו-PyTorch/XLA releases.
מידע נוסף על התקנת PyTorch/XLA זמין במאמר בנושא התקנת PyTorch/XLA.
מורידים את המאגר של HuggingFace ומתקינים את הדרישות.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=" git clone https://github.com/suexu1025/transformers.git vittransformers; \ cd vittransformers; \ pip3 install .; \ pip3 install datasets; \ wget https://github.com/pytorch/xla/blob/master/scripts/capture_profile.py"
אימון המודל
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
export PJRT_DEVICE=TPU
export PT_XLA_DEBUG=0
export USE_TORCH=ON
export TF_CPP_MIN_LOG_LEVEL=0
export XLA_USE_BF16=1
export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
cd vittransformers
python3 -u examples/pytorch/xla_spawn.py --num_cores 4 examples/pytorch/image-pretraining/run_mae.py --dataset_name=cifar10 \
--remove_unused_columns=False \
--label_names=pixel_values \
--mask_ratio=0.75 \
--norm_pix_loss=True \
--do_train=true \
--do_eval=true \
--base_learning_rate=1.5e-4 \
--lr_scheduler_type=cosine \
--weight_decay=0.05 \
--num_train_epochs=3 \
--warmup_ratio=0.05 \
--per_device_train_batch_size=8 \
--per_device_eval_batch_size=8 \
--logging_strategy=steps \
--logging_steps=30 \
--evaluation_strategy=epoch \
--save_strategy=epoch \
--load_best_model_at_end=True \
--save_total_limit=3 \
--seed=1337 \
--output_dir=MAE \
--overwrite_output_dir=true \
--logging_dir=./tensorboard-metrics \
--tpu_metrics_debug=true'
מחיקת ה-TPU והמשאב בתור
בסיום הסשן, מוחקים את ה-TPU ואת המשאב שמוכנס לתור.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
תוצאת ההשוואה לשוק
בטבלה הבאה מוצגים נתוני התפוקה של מדד ההשוואה לסוגים שונים של מאיצים.
| v5litepod-4 | v5litepod-16 | v5litepod-64 | |
| Epoch | 3 | 3 | 3 |
| גודל אצווה גלובלי | 32 | 128 | 512 |
| קצב העברת הנתונים (דוגמאות/שנייה) | 201 | 657 | 2,844 |