הרצת קוד JAX בפרוסות TPU
לפני שמריצים את הפקודות במסמך הזה, חשוב לוודא שפעלתם לפי ההוראות שבמאמר הגדרת חשבון ופרויקט Cloud TPU.
אחרי שהקוד של JAX פועל על לוח TPU יחיד, אפשר להגדיל את קנה המידה של הקוד על ידי הפעלתו בפרוסת TPU. פרוסות TPU הן כמה לוחות TPU שמחוברים זה לזה באמצעות חיבורי רשת ייעודיים במהירות גבוהה. המסמך הזה הוא מבוא להרצת קוד JAX בפרוסות TPU. למידע מעמיק יותר, אפשר לעיין במאמר בנושא שימוש ב-JAX בסביבות מרובות מארחים ומרובות תהליכים.
התפקידים הנדרשים
כדי לקבל את ההרשאות שדרושות ליצירת TPU ולהתחבר אליו באמצעות SSH, צריך לבקש מהאדמין להקצות לכם בפרויקט את תפקידי ה-IAM הבאים:
-
אדמין TPU (
roles/tpu.admin) -
משתמש בחשבון שירות (
roles/iam.serviceAccountUser) -
צפייה ב-Compute (
roles/compute.viewer)
להסבר על מתן תפקידים, ראו איך מנהלים את הגישה ברמת הפרויקט, התיקייה והארגון.
יכול להיות שאפשר לקבל את ההרשאות הנדרשות גם באמצעות תפקידים בהתאמה אישית או תפקידים מוגדרים מראש.
יצירת פרוסת Cloud TPU
יוצרים כמה משתני סביבה:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=europe-west4-b export ACCELERATOR_TYPE=v5litepod-32 export RUNTIME_VERSION=v2-alpha-tpuv5-lite
תיאורים של משתני סביבה
PROJECT_ID: מזהה הפרויקט ב- Google Cloud . משתמשים בפרויקט קיים או יוצרים פרויקט חדש.-
TPU_NAME: השם של ה-TPU. -
ZONE: האזור שבו תיצור את מכונת ה-TPU הווירטואלית. מידע נוסף על אזורים נתמכים זמין במאמר אזורים ותחומים של TPU. -
ACCELERATOR_TYPE: סוג המאיץ מציין את הגרסה והגודל של Cloud TPU שרוצים ליצור. מידע נוסף על סוגי המאיצים הנתמכים בכל גרסת TPU זמין במאמר בנושא גרסאות TPU. -
RUNTIME_VERSION: גרסת התוכנה של Cloud TPU.
יוצרים פרוסת TPU באמצעות הפקודה
gcloud. לדוגמה, כדי ליצור פרוסת v5litepod-32, משתמשים בפקודה הבאה:$ gcloud compute tpus tpu-vm create ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --accelerator-type=${ACCELERATOR_TYPE} \ --version=${RUNTIME_VERSION}
התקנת JAX בפלח
אחרי שיוצרים את פרוסת ה-TPU, צריך להתקין את JAX בכל המארחים בפרוסת ה-TPU. אפשר לעשות את זה באמצעות הפקודה gcloud compute tpus tpu-vm ssh עם הפרמטרים --worker=all ו---commamnd.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --worker=all \ --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
הרצת קוד JAX בפרוסה
כדי להריץ קוד JAX בפלח TPU, צריך להריץ את הקוד בכל מארח בפלח ה-TPU. השיחה ב-jax.device_count() מפסיקה להגיב עד שמתקשרים לכל מארח בפלח. בדוגמה הבאה אפשר לראות איך מריצים חישוב JAX על פרוסת TPU.
הכנת הקוד
נדרשת גרסה gcloud ומעלה (לפקודה scp).
משתמשים בפקודה gcloud --version כדי לבדוק את הגרסה של gcloud, ומריצים את הפקודה gcloud components upgrade אם צריך.
יוצרים קובץ בשם example.py עם הקוד הבא:
import jax
# Initialize the slice
jax.distributed.initialize()
# The total number of TPU cores in the slice
device_count = jax.device_count()
# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()
# The psum is performed over all mapped devices across the slice
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)
# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
print('global device count:', jax.device_count())
print('local device count:', jax.local_device_count())
print('pmap result:', r)
העתקה של example.py לכל מכונות ה-VM של העובדים ב-TPU בפרוסת ה-TPU
$ gcloud compute tpus tpu-vm scp ./example.py ${TPU_NAME}: \ --worker=all \ --zone=${ZONE} \ --project=${PROJECT_ID}
אם לא השתמשתם בעבר בפקודה scp, יכול להיות שתופיע שגיאה דומה לזו:
ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH
agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try
again.
כדי לפתור את השגיאה, מריצים את הפקודה ssh-add כמו שהיא מוצגת בהודעת השגיאה ומריצים אותה מחדש.
הרצת הקוד בפרוסה
מפעילים את התוכנה example.py בכל מכונת VM:
$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --worker=all \ --command="python3 ./example.py"
פלט (נוצר באמצעות פרוסת v5litepod-32):
global device count: 32
local device count: 4
pmap result: [32. 32. 32. 32.]
הסרת המשאבים
כשמסיימים להשתמש במכונת ה-TPU הווירטואלית, פועלים לפי השלבים הבאים כדי לנקות את המשאבים.
מוחקים את המשאבים של Cloud TPU ו-Compute Engine.
$ gcloud compute tpus tpu-vm delete ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID}
כדי לוודא שהמשאבים נמחקו, מריצים את הפקודה
gcloud compute tpus execution-groups list. יכול להיות שיחלפו כמה דקות עד שהמחיקה תסתיים. הפלט מהפקודה הבאה לא אמור לכלול אף אחד מהמשאבים שנוצרו במדריך הזה:$ gcloud compute tpus tpu-vm list --zone=${ZONE} \ --project=${PROJECT_ID}