Menskalakan model di TPU

Dokumen ini memberikan ringkasan tentang cara menskalakan model bahasa: cara kerja TPU dan cara TPU berkomunikasi satu sama lain, cara LLM berjalan di hardware sungguhan, serta cara melakukan paralelisasi model Anda selama pelatihan dan inferensi agar berjalan secara efisien dalam skala besar. Kami memberikan informasi yang membantu Anda menilai seberapa mahal biaya pelatihan LLM, seberapa besar memori yang Anda butuhkan untuk menayangkan model, dan cara melakukan sharding model secara efektif di beberapa TPU.

Meskipun banyak deep learning yang kompleks, mengoptimalkan performa model Anda tidak harus demikian, bahkan dalam skala besar. Prinsip dasar berlaku di mana saja — mulai dari menangani satu akselerator hingga puluhan ribu — dan memahaminya memungkinkan Anda melakukan banyak hal yang berguna:

  • Perkirakan seberapa dekat bagian model Anda dengan titik optimal teoretisnya.
  • Membuat pilihan yang tepat tentang berbagai skema paralelisme pada skala yang berbeda (cara Anda membagi komputasi di beberapa perangkat).
  • Perkirakan biaya dan waktu yang diperlukan untuk melatih dan menjalankan model Transformer besar.
  • Merancang algoritma yang memanfaatkan arsitektur TPU.
  • Merancang arsitektur model yang didorong oleh pemahaman eksplisit tentang apa yang membatasi performa algoritma.

Prasyarat

Anda harus memiliki pemahaman dasar tentang LLM dan arsitektur Transformer, tetapi tidak harus memahami cara kerjanya dalam skala besar. Anda harus memahami dasar-dasar pelatihan LLM dan idealnya memiliki pemahaman dasar tentang JAX. Bacaan latar belakang yang berguna untuk arsitektur Transformer mencakup:

Setelah memahami prasyarat ini, Anda akan merasa nyaman memperkirakan skema paralelisme terbaik untuk model Transformer di platform TPU tertentu. Anda juga dapat memperkirakan durasi pelatihan dan inferensi.

Pentingnya penskalaan model

LLM dan sebagian besar model kecil saat ini berjalan sangat dekat dengan batas hardware sehingga pengembangan model mengharuskan Anda memikirkan efisiensi dalam skala besar. Peningkatan 20% pada tolok ukur tidak relevan jika harus mengorbankan efisiensi 20% pada batas atas. Arsitektur model yang menjanjikan sering kali gagal karena tidak dapat berjalan secara efisien dalam skala besar atau karena kurangnya upaya pengoptimalan untuk membuatnya berjalan secara efisien.

Tujuan penskalaan model adalah untuk dapat meningkatkan jumlah chip yang digunakan untuk pelatihan atau inferensi sekaligus mencapai peningkatan throughput yang proporsional dan linear. Hal ini dikenal sebagai penskalaan kuat. Meskipun menambahkan chip tambahan (paralelisme) biasanya mengurangi waktu komputasi, hal ini juga menimbulkan biaya komunikasi tambahan antar-chip. Jika komunikasi membutuhkan waktu lebih lama daripada komputasi, model akan terikat dengan komunikasi dan tidak dapat diskalakan dengan baik. Memahami hardware dengan cukup baik untuk mengantisipasi munculnya hambatan ini memungkinkan Anda mendesain atau mengonfigurasi ulang model untuk menghindari hambatan ini.

Bagian berikut memberikan ringkasan tentang cara menskalakan hardware TPU dan cara arsitektur Transformer berkembang. Informasi ini berguna bagi peneliti yang mendesain arsitektur baru dan engineer yang berupaya membuat LLM generasi saat ini berjalan dengan cepat.

Bagian 1: Konsep

Bagian ini menjelaskan analisis roofline dan faktor-faktor yang membatasi kemampuan model untuk melakukan penskalaan (komunikasi, komputasi, dan memori). Selanjutnya, kami menjelaskan cara kerja TPU, baik sebagai chip individual maupun — yang sangat penting — sebagai sistem yang saling terhubung dengan latensi dan bandwidth terbatas antar-chip.

  • Pengantar analisis roofline: Bagian ini menjelaskan cara memperkirakan seberapa cepat algoritma Anda akan berjalan berdasarkan batas komputasi, komunikasi, dan memori.
  • Operasi pada arsitektur TPU: Bagian ini menjelaskan arsitektur TPU, cara kerja berbagai modul hardware di TPU, dan pengaruhnya terhadap pelatihan dan penyajian model.
  • Sharding model untuk paralelisme multi-TPU: Bagian ini membahas sharding model dan paralelisme multi-TPU dengan menjelaskan perkalian matriks yang di-shard.

Bagian 2: Menskalakan Transformer

Penting untuk memahami setiap bagian arsitektur Transformer: ukuran pasti setiap matriks, tempat normalisasi terjadi, dan jumlah parameter serta FLOP di setiap bagian. Bagian ini membahas matematika Transformer ini dengan cermat, yang menunjukkan cara menghitung parameter dan FLOP untuk pelatihan dan inferensi. Hal ini memberi tahu Anda berapa banyak memori yang akan digunakan model, berapa banyak waktu yang akan Anda habiskan untuk komputasi atau komunikasi, dan kapan perhatian akan menjadi penting dibandingkan dengan blok feed-forward.

Terakhir, bagian ini membantu Anda mendapatkan jawaban atas pertanyaan mendasar: mengingat model berukuran tertentu dan dilengkapi dengan sejumlah chip, bagaimana cara memparalelkan model agar tetap dalam kondisi penskalaan yang kuat. Untuk menjawab pertanyaan ini, bagian ini membahas empat teknik paralelisme utama yang digunakan untuk membagi model di beberapa chip: data, tensor, pipeline, dan pakar. Dokumen ini juga menjelaskan teknik lain untuk mengurangi persyaratan memori seperti rematerialisasi, sharding model yang didukung ZeRO, pelepasan host, dan akumulasi gradien.

  • Pengantar operasi matematika Transformer: Bagian ini membahas matematika untuk menjawab pertanyaan tentang jumlah FLOP yang digunakan oleh Transformer selama penerusan dan penerusan balik, perhitungan untuk menghitung jumlah parameter, dan ukuran cache KV.
  • Paralelisasi transformer untuk pelatihan: Bagian ini menjelaskan proses untuk memaksimalkan efisiensi pelatihan dengan mengoordinasikan FSDP, pengelompokan Megatron, dan paralelisme pipeline. Bagian ini menjelaskan cara menentukan distribusi optimal untuk ukuran model dan ukuran batch tertentu di sejumlah chip tetap untuk mencapai throughput puncak.
    • Melatih Llama 3 di TPU: Subbagian ini menjelaskan cara melatih Llama 3 di TPU, berapa lama waktu yang mungkin diperlukan, dan berapa biayanya.
  • Penskalaan Transformer untuk inferensi: Setelah model dilatih, model tersebut perlu disajikan. Inferensi menambahkan pertimbangan baru, latensi, dan mengubah lanskap memori. Bagian ini menjelaskan cara kerja penayangan yang tidak digabungkan dan cara memikirkan cache KV.
    • Menyajikan Llama 3 di TPU: Subbagian ini menjelaskan cara menyajikan Llama 3 di TPU, perkiraan biayanya, serta pertukaran latensi dan throughput.

Bagian 3: Implementasi praktis

Bagian ini menjelaskan cara menerapkan konsep penskalaan menggunakan JAX, serta cara membuat profil dan men-debug kode Anda jika terjadi kesalahan.

  • Membuat profil program TPU: LLM sungguhan bersifat kompleks dan sulit dikembangkan, dioptimalkan, serta di-debug. Bagian ini menjelaskan stack JAX + XLA dan cara menggunakan profiler JAX/TensorBoard untuk men-debug dan memperbaiki masalah nyata.
  • Memprogram TPU di JAX: Bagian ini menjelaskan cara menggunakan JAX API untuk memparalelkan komputasi.