Menyajikan Gemma menggunakan TPU di GKE dengan JetStream

Tutorial ini menunjukkan cara menayangkan model bahasa besar (LLM) Gemma menggunakan Unit Pemrosesan Tensor (TPU) di Google Kubernetes Engine (GKE). Anda men-deploy container siap pakai dengan JetStream dan MaxText ke GKE. Anda juga mengonfigurasi GKE untuk memuat bobot Gemma 7B dari Cloud Storage saat runtime.

Tutorial ini ditujukan untuk Engineer machine learning (ML), Admin dan operator platform, serta Spesialis data dan AI yang tertarik untuk menggunakan kemampuan orkestrasi container Kubernetes dalam menayangkan LLM. Untuk mempelajari lebih lanjut peran umum dan contoh tugas yang kami referensikan dalam Google Cloud konten, lihat Peran dan tugas pengguna GKE umum.

Sebelum membaca halaman ini, pastikan Anda memahami hal-hal berikut:

Latar belakang

Bagian ini menjelaskan teknologi utama yang digunakan dalam tutorial ini.

Gemma

Gemma adalah serangkaian model kecerdasan buatan (AI) generatif yang ringan dan tersedia secara terbuka yang dirilis dengan lisensi terbuka. Model AI ini tersedia untuk dijalankan di aplikasi, hardware, perangkat seluler, atau layanan yang dihosting. Anda dapat menggunakan model Gemma untuk pembuatan teks, tetapi Anda juga dapat menyesuaikan model ini untuk tugas khusus.

Untuk mempelajari lebih lanjut, lihat dokumentasi Gemma.

TPU

TPU adalah sirkuit terintegrasi khusus aplikasi (ASIC) yang dikembangkan khusus oleh Google dan digunakan untuk mempercepat model machine learning dan AI yang dibuat menggunakan framework seperti TensorFlow, PyTorch, dan JAX.

Tutorial ini membahas cara menayangkan model Gemma 7B. GKE men-deploy model pada node TPUv5e host tunggal dengan topologi TPU yang dikonfigurasi berdasarkan persyaratan model untuk menyajikan perintah dengan latensi rendah.

JetStream

JetStream adalah framework penayangan inferensi open source yang dikembangkan oleh Google. JetStream memungkinkan inferensi berperforma tinggi, throughput tinggi, dan dioptimalkan untuk memori di TPU dan GPU. Framework ini menyediakan pengoptimalan performa lanjutan, termasuk teknik pengelompokan berkelanjutan dan kuantisasi, untuk memfasilitasi deployment LLM. JetStream memungkinkan penayangan PyTorch/XLA dan JAX TPU mencapai performa yang optimal.

Untuk mempelajari lebih lanjut pengoptimalan ini, lihat repositori project JetStream PyTorch dan JetStream MaxText.

MaxText

MaxText adalah implementasi LLM JAX yang berperforma tinggi, skalabel, dan dapat disesuaikan, yang dibangun di atas library JAX open source seperti Flax, Orbax, dan Optax. Implementasi LLM khusus dekoder MaxText ditulis dalam Python. Hal ini memanfaatkan compiler XLA secara intensif untuk mencapai performa tinggi tanpa perlu membuat kernel kustom.

Untuk mempelajari lebih lanjut model dan ukuran parameter terbaru yang didukung MaxText, lihat repositori project MaxText.