הפעלת חישוב במכונה וירטואלית של Cloud TPU באמצעות JAX
במאמר הזה מוסבר בקצרה איך עובדים עם JAX ו-Cloud TPU.
לפני שמתחילים
לפני שמריצים את הפקודות במאמר הזה, צריך ליצור Google Cloudחשבון, להתקין את Google Cloud CLI ולהגדיר את הפקודה gcloud. מידע נוסף זמין במאמר בנושא הגדרת סביבת Cloud TPU.
התפקידים הנדרשים
כדי לקבל את ההרשאות שדרושות ליצירת TPU ולהתחבר אליו באמצעות SSH, צריך לבקש מהאדמין להקצות לכם בפרויקט את תפקידי ה-IAM הבאים:
-
אדמין TPU (
roles/tpu.admin) -
משתמש בחשבון שירות (
roles/iam.serviceAccountUser) -
צפייה ב-Compute (
roles/compute.viewer)
להסבר על מתן תפקידים, ראו איך מנהלים את הגישה ברמת הפרויקט, התיקייה והארגון.
יכול להיות שאפשר לקבל את ההרשאות הנדרשות גם באמצעות תפקידים בהתאמה אישית או תפקידים מוגדרים מראש.
יצירת מכונת Cloud TPU וירטואלית באמצעות gcloud
כדי להקל על השימוש בפקודות, אפשר להגדיר כמה משתני סביבה.
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-east5-a export ACCELERATOR_TYPE=v5litepod-8 export RUNTIME_VERSION=v2-alpha-tpuv5-lite
תיאורים של משתני סביבה
PROJECT_ID: מזהה הפרויקט ב- Google Cloud . משתמשים בפרויקט קיים או יוצרים פרויקט חדש.-
TPU_NAME: השם של ה-TPU. -
ZONE: האזור שבו תיצור את מכונת ה-TPU הווירטואלית. מידע נוסף על אזורים נתמכים זמין במאמר אזורים ותחומים של TPU. -
ACCELERATOR_TYPE: סוג המאיץ מציין את הגרסה והגודל של Cloud TPU שרוצים ליצור. מידע נוסף על סוגי המאיצים הנתמכים בכל גרסת TPU זמין במאמר בנושא גרסאות TPU. -
RUNTIME_VERSION: גרסת התוכנה של Cloud TPU.
כדי ליצור מכונת TPU, מריצים את הפקודה הבאה מ-Cloud Shell או מהטרמינל של המחשב שבו מותקן Google Cloud CLI.
$ gcloud compute tpus tpu-vm create $TPU_NAME \ --project=$PROJECT_ID \ --zone=$ZONE \ --accelerator-type=$ACCELERATOR_TYPE \ --version=$RUNTIME_VERSION
התחברות ל-Cloud TPU VM
מתחברים למכונת ה-TPU הווירטואלית באמצעות SSH באמצעות הפקודה הבאה:
$ gcloud compute tpus tpu-vm ssh $TPU_NAME \ --project=$PROJECT_ID \ --zone=$ZONE
אם לא הצלחתם להתחבר למכונת TPU וירטואלית באמצעות SSH, יכול להיות שלמכונת ה-TPU הווירטואלית אין כתובת IP חיצונית. כדי לגשת למכונת TPU וירטואלית ללא כתובת IP חיצונית, פועלים לפי ההוראות במאמר התחברות למכונת TPU וירטואלית ללא כתובת IP ציבורית.
התקנת JAX במכונת ה-VM של Cloud TPU
(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
בדיקת המערכת
מוודאים של-JAX יש גישה ל-TPU והיא יכולה להריץ פעולות בסיסיות:
מפעילים את המתרגם של Python 3:
(vm)$ python3>>> import jax
הצגת מספר ליבות ה-TPU הזמינות:
>>> jax.device_count()
מוצג מספר ליבות ה-TPU. מספר ליבות ה-TPU שמוצג תלוי בגרסת ה-TPU שבה אתם משתמשים. מידע נוסף זמין במאמר בנושא גרסאות TPU.
ביצוע חישוב
>>> jax.numpy.add(1, 1)
התוצאה של הפעולה numpy add מוצגת:
פלט מהפקודה:
Array(2, dtype=int32, weak_type=True)
יציאה ממתורגמן Python
>>> exit()
הרצת קוד JAX במכונה וירטואלית של TPU
עכשיו אפשר להריץ כל קוד JAX שרוצים. הדוגמאות של Flax הן מקום מצוין להתחיל בו כדי להריץ מודלים סטנדרטיים של למידת מכונה ב-JAX. לדוגמה, כדי לאמן רשת בסיסית של MNIST:
מתקינים את יחסי התלות של דוגמאות Flax:
(vm)$ pip install --upgrade clu (vm)$ pip install tensorflow (vm)$ pip install tensorflow_datasets
מתקינים את Flax:
(vm)$ git clone https://github.com/google/flax.git (vm)$ pip install --user flax
מריצים את סקריפט האימון של Flax MNIST:
(vm)$ cd flax/examples/mnist (vm)$ python3 main.py --workdir=/tmp/mnist \ --config=configs/default.py \ --config.learning_rate=0.05 \ --config.num_epochs=5
הסקריפט מוריד את מערך הנתונים ומתחיל באימון. פלט הסקריפט אמור להיראות כך:
I0214 18:00:50.660087 140369022753856 train.py:146] epoch: 1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88 I0214 18:00:52.015867 140369022753856 train.py:146] epoch: 2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72 I0214 18:00:53.377511 140369022753856 train.py:146] epoch: 3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04 I0214 18:00:54.727168 140369022753856 train.py:146] epoch: 4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15 I0214 18:00:56.082807 140369022753856 train.py:146] epoch: 5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18
הסרת המשאבים
כדי לא לצבור חיובים לחשבון Google Cloud על המשאבים שבהם השתמשתם בדף הזה:
כשמסיימים להשתמש במכונת ה-TPU הווירטואלית, מבצעים את השלבים הבאים כדי לנקות את המשאבים.
אם עדיין לא עשיתם זאת, מתנתקים ממופע Cloud TPU:
(vm)$ exit
ההנחיה אמורה להיות עכשיו username@projectname, כדי להראות שאתם ב-Cloud Shell.
כדי למחוק את Cloud TPU:
$ gcloud compute tpus tpu-vm delete $TPU_NAME \ --project=$PROJECT_ID \ --zone=$ZONE
מריצים את הפקודה הבאה כדי לוודא שהמשאבים נמחקו. מוודאים שכרטיס ה-TPU לא מופיע יותר ברשימה. תהליך המחיקה עשוי להימשך כמה דקות.
$ gcloud compute tpus tpu-vm list \ --zone=$ZONE
הערות לגבי ביצועים
אלה כמה פרטים חשובים שרלוונטיים במיוחד לשימוש ב-TPU ב-JAX.
מרווח פנימי
אחת הסיבות הנפוצות ביותר לביצועים איטיים ב-TPU היא הוספה של ריפוד לא מכוון:
- מערכים ב-Cloud TPU מחולקים למקטעים. המשמעות היא שצריך להוסיף ריפוד לאחד המאפיינים כך שיהיה כפולה של 8, ולמאפיין אחר כך שיהיה כפולה של 128.
- הביצועים של יחידת הכפל המטריציוני הם הכי טובים כשמשתמשים בזוגות של מטריצות גדולות שבהן הצורך בריפוד הוא מינימלי.
סוג הנתונים bfloat16
כברירת מחדל, הכפלת מטריצות ב-JAX ב-TPU משתמשת ב-bfloat16 עם צבירה של float32. אפשר לשלוט בזה באמצעות הארגומנט precision בקריאות רלוונטיות לפונקציות jax.numpy (matmul, dot, einsum וכו'). הקפידו במיוחד על הדברים הבאים:
-
precision=jax.lax.Precision.DEFAULT: שימוש בדיוק מעורב של bfloat16 (המהיר ביותר) -
precision=jax.lax.Precision.HIGH: משתמש בכמה מעברים של MXU כדי להשיג רמת דיוק גבוהה יותר -
precision=jax.lax.Precision.HIGHEST: משתמש בעוד יותר מעברים של MXU כדי להשיג דיוק מלא של float32
בנוסף, JAX מוסיף את סוג הנתונים bfloat16, שאפשר להשתמש בו כדי להמיר מערכים באופן מפורש ל-bfloat16. לדוגמה, jax.numpy.array(x, dtype=jax.numpy.bfloat16).
המאמרים הבאים
מידע נוסף על Cloud TPU: