Pengoptimalan performa TPU7x (Ironwood)
Panduan ini menjelaskan beberapa metode untuk mengoptimalkan performa dengan TPU7x (Ironwood) dengan mengelola pergerakan data secara efisien di antara sistem memori bertingkatnya. Hal ini mencakup teknik seperti pelatihan presisi rendah, sharding, pengoptimalan komunikasi, rematerialisasi aktivasi, penyesuaian memori virtual yang tercakup, dan kernel akselerator kustom.
Untuk mengoptimalkan performa dengan TPU7x, Anda harus memahami arsitektur Ironwood terlebih dahulu, khususnya hierarki memori dan topologi interkoneksi. Untuk informasi selengkapnya, lihat TPU7x (Ironwood).
Pelatihan presisi rendah dengan FP8
FP8 (floating point 8 bit) adalah format data numerik yang efisien dan digunakan terutama untuk mempercepat pelatihan dan inferensi model. Dengan merepresentasikan angka menggunakan 8 bit – bukan format 16-bit standar (FP16 atau BF16) dan 32-bit (FP32) – TPU dapat memproses data secara signifikan lebih cepat dan menggunakan lebih sedikit memori.
TPU7x mendukung akselerasi hardware bawaan untuk jenis data FP8, yang menawarkan performa teoretis puncak sebesar 4614 TFLOPS per chip. Kemampuan ini dapat menghasilkan waktu pelatihan end-to-end yang jauh lebih cepat. Untuk operasi yang kompatibel, terutama perkalian matriks padat yang umum untuk beban kerja AI, penggunaan FP8 dapat menghasilkan peningkatan performa sebesar 1,3x dibandingkan pelatihan BF16 standar. Dibandingkan dengan BF16, FP8 menggandakan FLOP puncak dan mengurangi separuh footprint memori untuk bobot dan aktivasi. FP8 harus menjadi tuas penyesuaian utama untuk beban kerja yang terikat komputasi dan skenario yang dibatasi oleh kapasitas atau bandwidth memori.
Penggunaan FP8 menawarkan manfaat performa berikut:
- Mengurangi tekanan memori bandwidth tinggi (HBM): Jejak memori yang lebih kecil memungkinkan model yang lebih besar, atau model dengan cache KV yang lebih besar selama inferensi, agar sepenuhnya sesuai dengan HBM 192 GB. Hal ini menghindari pelepasan yang mahal ke memori host yang lebih lambat.
- Ukuran tumpukan efektif yang lebih besar: Dengan mengurangi memori yang diperlukan untuk aktivasi, FP8 memungkinkan penggunaan ukuran tumpukan yang lebih besar. Hal ini meningkatkan paralelisme data dan dapat menghasilkan throughput yang lebih tinggi dan pemanfaatan unit komputasi yang lebih baik.
- Persyaratan bandwidth memori yang lebih rendah: Memindahkan setengah jumlah data untuk setiap operasi mengurangi permintaan pada jalur data HBM ke MXU. Pada sistem yang pergerakan datanya sering menjadi hambatan, hal ini membantu menjaga MXU tetap penuh dengan pekerjaan.
Penggunaan FP8 dengan penurunan performa nol atau terbatas memerlukan pemilihan teknik kuantisasi yang cermat. Berikut beberapa praktik terbaik yang perlu dipertimbangkan untuk pelatihan FP8:
- Granularitas penskalaan: Mulailah dengan penskalaan per-tensor sebagai dasar pengukuran. Jika ada masalah kualitas atau performa, beralihlah ke penskalaan per sumbu. Penskalaan subchannel mungkin tidak diperlukan.
- Mode penskalaan: Penskalaan dinamis, yang menghitung faktor penskalaan saat runtime, adalah default yang baik untuk mempertahankan kualitas. Meskipun penskalaan statis dapat menawarkan peningkatan performa yang signifikan dengan menghilangkan komputasi, penskalaan ini memerlukan pembuatan profil yang cermat untuk menentukan faktor penskalaan yang benar dan mungkin tidak cocok untuk semua kasus penggunaan, terutama saat konfigurasi model berubah. Sebaliknya, beberapa model dan konfigurasi yang andal dapat memperbaiki skala ke batas FP8 untuk bobot atau aktivasi, sehingga Anda dapat mengurangi overhead kuantisasi sambil mempertahankan akurasi dan meningkatkan performa.
- Format FP8 (E4M3 dan E5M2): Pendekatan umum dan efektif adalah menggunakan campuran format FP8. Misalnya, gunakan E4M3 untuk bobot dan aktivasi dalam forward pass untuk memanfaatkan presisi E4M3 yang lebih tinggi, dan gunakan E5M2 untuk gradien dalam backward pass untuk mengakomodasi rentang dinamis gradien yang lebih luas.
- Pembulatan: Menggunakan "bulatkan ke bilangan genap terdekat" (RNE) dan bukan pembulatan stokastik untuk gradien dapat mempertahankan kualitas sekaligus menawarkan performa dan reproduksibilitas yang lebih baik.
- Mengaktifkan FP8 di MaxText:
MaxText mendukung pelatihan FP8
melalui library kuantisasi QWIX. Untuk mengaktifkan kuantisasi, tetapkan
flag berikut dalam konfigurasi Anda:
use_qwix_quantization=true.
Sharding dan paralelisme
Sharding adalah proses membagi model besar atau data pelatihannya menjadi beberapa bagian yang lebih kecil dan mendistribusikannya ke beberapa chip atau core TPU. Memilih strategi sharding yang tepat penting untuk mencapai performa tinggi di TPU7x.
Pendekatan sederhana yang hanya memaksimalkan tingkat paralelisme sering kali menghasilkan performa yang buruk karena terikat dengan komunikasi. Pendekatan terbaik sering kali adalah memilih strategi sharding paling sederhana yang memenuhi batasan memori, karena hal ini meminimalkan overhead komunikasi dan memungkinkan unit komputasi digunakan secara efisien.
Sebelum memilih strategi sharding, langkah pertama dalam upaya penyesuaian performa apa pun harus berupa analisis intensitas aritmatika. Analisis ini menentukan apakah komputasi tertentu dibatasi oleh komputasi, bandwidth memori, atau bandwidth interkoneksi. Metrik ini dihitung sebagai rasio operasi floating point terhadap byte data yang harus dipindahkan.
Intensitas aritmatika yang tinggi menunjukkan workload terikat komputasi. Intensitas aritmatika yang rendah menunjukkan beban kerja yang terikat dengan memori atau komunikasi, yang performanya dibatasi oleh kecepatan data dapat dipindahkan dari HBM atau di seluruh jaringan ICI. Analisis ini akan menentukan ukuran batch dan strategi sharding yang ideal. Misalnya, workload yang terikat komunikasi tidak akan mendapatkan manfaat dari strategi sharding yang memperkenalkan lebih banyak komunikasi, seperti paralelisme tensor tingkat tinggi.
Framework keputusan strategi sharding
MaxText menawarkan berbagai strategi sharding. Pilihan optimal bergantung pada arsitektur model, panjang urutan, dan kebutuhan untuk menyeimbangkan beban komputasi dengan overhead komunikasi.
- Fully Sharded Data Parallelism (FSDP): Ini adalah strategi default yang lebih disukai untuk paralelisme data. FSDP membagi bobot model, gradien, dan status pengoptimal di seluruh perangkat paralel data. Selama komputasi, setiap perangkat melakukan operasi All-Gather untuk mengambil bobot penuh yang diperlukan untuk microbatch lokalnya. FSDP sangat efektif selama ukuran batch per perangkat cukup besar untuk menyembunyikan latensi komunikasi All-Gather ini. Untuk model Mixture-of-Experts (MoE), perhitungan intensitas aritmetika harus memperhitungkan kejarangan.
- Paralelisme Tensor (TP): TP membagi tensor individual di seluruh perangkat. Biasanya, tensor adalah matriks bobot dalam multilayer perceptron (MLP) dan blok perhatian. Intensitas aritmetika hardware yang tinggi (11,5k) menimbulkan persyaratan yang sangat tinggi pada dimensi model agar TP layak digunakan melalui ICI, dan upaya untuk menggunakan TP dapat menyebabkan sistem terikat komunikasi.
- Paralelisme Pakar (EP): Ini adalah strategi standar dan diperlukan untuk melatih model MoE. EP membagi lapisan "pakar" di seluruh rangkaian perangkat, dan kolektif komunikasi All-to-All digunakan untuk merutekan token ke perangkat pakar yang ditentukan. EP dapat efisien jika dimensi MLP model cukup besar untuk mendekati roofline.
- Paralelisme Konteks (CP): CP adalah strategi khusus yang penting untuk melatih model dengan panjang urutan yang sangat panjang. Fungsi utamanya adalah mengelola konsumsi memori aktivasi, yang tumbuh secara kuadratik dengan panjang urutan dan dapat melebihi kapasitas HBM. CP membagi dimensi urutan tensor aktivasi, yang memungkinkan penggunaan ukuran batch per perangkat fraksional. Karena CP memperkenalkan lebih banyak komunikasi daripada FSDP, aturan umumnya adalah menggunakan tingkat CP minimum yang diperlukan untuk memenuhi batasan memori dan memastikan shard sumbu batch tetap berupa bilangan bulat.
Tabel berikut memetakan jenis beban kerja umum ke strategi sharding yang optimal:
| Jenis workload | Sharding utama yang direkomendasikan | Sharding sekunder | Hambatan utama | Alasan |
|---|---|---|---|---|
| Model padat - urutan pendek | FSDP | T/A | Rematerialisasi, FF Matmuls | FSDP memberikan keseimbangan terbaik. Dengan urutan pendek, memori aktivasi mungkin bukan masalah utama. Kuncinya adalah batch global yang cukup besar untuk menyembunyikan bobot pengumpulan semua FSDP. Saat ukuran batch bertambah, ukuran aktivasi juga bertambah, dan kebijakan rematerialisasi yang sesuai diperlukan untuk memastikan konfigurasi ini tidak kehabisan memori. |
| Model padat - urutan panjang | FSDP | CP | Perhatian kilat, memori aktivasi | Memori aktivasi menjadi batasan utama. CP diperlukan untuk mengaktifkan ukuran batch per perangkat fraksional dan menghindari masalah kehabisan memori (OOM). Perhatian kilat adalah sumber utama komputasi dan waktu yang terbuang. |
| Model MoE - urutan pendek | FSDP + EP | T/A | All-to-All (Perutean pakar), rematerialisasi | Model MoE memerlukan EP untuk membagi pakar. Komunikasi All-to-All untuk perutean token adalah hambatan utama yang harus diatasi. Rematerialisasi juga merupakan sumber limbah yang signifikan. |
| Model MoE - skala sangat besar | FSDP + EP + PP | Paralelisme model (MP) | Semua hambatan yang disebutkan sebelumnya, ditambah balon pipeline | Untuk model yang melebihi memori satu pod, PP diperlukan untuk memecah lapisan di seluruh pod. Hal ini memperkenalkan komunikasi DCN dan biaya overhead balon pipeline. Ini adalah konfigurasi yang sangat kompleks yang memerlukan penyesuaian yang cermat. |
Pengoptimalan komunikasi
Mekanisme utama untuk tumpang-tindih komunikasi dan komputasi di TPU7x disebut SparseCore Collective Offloading. Arsitektur Ironwood mencakup unit SparseCore khusus, yang bertindak sebagai thread kontrol independen yang mampu mengelola pergerakan data melalui fabric ICI. Hal ini memungkinkan operasi komunikasi kolektif (seperti All-Gather atau Reduce-Scatter) dieksekusi secara paralel dengan komputasi utama yang terjadi di TensorCore. Ini adalah metode yang direkomendasikan untuk kolektif asinkron di TPU7x. Gunakan flag yang direkomendasikan untuk mengaktifkan pelepasan tugas untuk kolektifitas yang paling umum.
Rematerialisasi aktivasi
Rematerialisasi aktivasi, juga dikenal sebagai checkpointing gradien, adalah teknik mendasar untuk mengurangi footprint HBM suatu model. Daripada menyimpan semua aktivasi perantara dari forward pass di HBM untuk digunakan selama backward pass, teknik ini hanya menyimpan beberapa aktivasi utama (titik pemeriksaan) dan menghitung ulang aktivasi lainnya sesuai permintaan selama backward pass. Hal ini menghemat memori dalam jumlah yang signifikan dengan mengorbankan peningkatan komputasi (sekitar 25-30% FLOP tambahan untuk blok transformer standar).
Keputusan tentang seberapa agresif penerapan rematerialisasi adalah parameter penyesuaian penting yang sepenuhnya bergantung pada hambatan utama, yang sering kali bervariasi dengan panjang urutan.
Untuk beban kerja urutan panjang (seperti 128k): Dalam kasus ini, ukuran tensor aktivasi adalah konsumen HBM yang dominan. Workload biasanya terikat memori. Oleh karena itu, menerapkan kebijakan rematerialisasi yang agresif sangat bermanfaat. Penghematan memori memungkinkan pelatihan dilanjutkan tanpa error kehabisan memori dan juga memungkinkan ukuran batch yang lebih besar, dan overhead komputasi untuk menghitung ulang adalah pertukaran yang berharga.
Untuk beban kerja urutan pendek (seperti 8k): Dalam kasus ini, memori aktivasi tidak terlalu menjadi masalah, dan beban kerja cenderung terikat komputasi. Overhead komputasi rematerialisasi dapat menjadi sumber inefisiensi terbesar.
Menyesuaikan kebijakan rematerialisasi di MaxText
MaxText memberikan kontrol terperinci atas rematerialisasi melalui serangkaian kebijakan preset dan kustom, yang dikonfigurasi menggunakan tanda remat_policy.
Kebijakan preset
MaxText menawarkan kebijakan bawaan berikut:
full: Kebijakan paling agresif, yang merealisasikan ulang hampir semuanya. Opsi ini meminimalkan penggunaan HBM, tetapi memaksimalkan beban komputasi ulang. Ideal untuk skenario urutan panjang dengan batasan memori yang sangat ketat.minimal: Kebijakan yang paling tidak agresif, menyimpan sebagian besar aktivasi. Hal ini memaksimalkan penggunaan HBM, tetapi meminimalkan penghitungan ulang. Paling cocok untuk workload terikat komputasi dengan urutan pendek, yang tidak memerlukan banyak memori.- Kebijakan menengah: Opsi seperti
save_dot_with_context_except_mlp,save_qkv_proj, dansave_out_projmemberikan berbagai trade-off dengan membuat titik pemeriksaan secara selektif pada output operasi dot-product yang mahal sekaligus merealisasikan kembali operasi element-wise yang lebih murah.
Kebijakan kustom
Untuk tingkat kontrol yang lebih besar, Anda dapat menyetel remat_policy ke custom. Hal ini memungkinkan Anda menentukan perilaku untuk setiap lapisan dalam modul dekode model. Setiap lapisan dapat diberi salah satu dari tiga perilaku:
device: Pengaktifan disimpan di HBM pada perangkat TPU.remat: Aktivasi dibatalkan dan akan diwujudkan kembali selama backward pass.offload: Aktivasi dipindahkan dari HBM ke memori host CPU, sehingga mengosongkan HBM dengan mengorbankan latensi transfer PCIe.
Penyesuaian VMEM yang tercakup
Performa kernel, seperti flash attention, bergantung pada ukuran petak yang dipilih dalam kernel, yang ukurannya dibatasi oleh memori vektor (VMEM) yang tersedia. Chip TPU7x memiliki VMEM 64 MB, yang dapat dibagi antara cakupan saat ini (VMEM tercakup) dan pengambilan data bobot di masa mendatang. Meningkatkan VMEM yang tercakup memungkinkan peningkatan ukuran petak di kernel, yang berpotensi mengurangi penundaan memori dan meningkatkan performa kernel. Anda dapat mengubah ukuran VMEM yang tercakup dengan menetapkan
xla_tpu_scoped_vmem_limit_kib (dalam LIBTPU_INIT_ARGS), yang dapat digunakan untuk
menjelajahi performa kernel serta batas performa end-to-end.
Mengoptimalkan ukuran VMEM yang tercakup secara tidak langsung dapat memengaruhi performa kernel Pallas kustom karena peningkatan VMEM yang tercakup akan membuka ruang penelusuran hyperparameter yang lebih besar untuk ukuran petak dalam kernel.
Kernel Tokamax
Tokamax, library kernel JAX berperforma tinggi dengan banyak kernel TPU yang sangat dioptimalkan, mengatasi beberapa hambatan umum khusus hardware:
- Perhatian splash: Perhatian splash digunakan sebagai penerapan perhatian utama untuk menghilangkan hambatan HBM dari perhatian standar dan menggunakan penerapan perhatian yang paling efisien di TPU.
- Perkalian Matriks yang Dikelompokkan Megablox (GMM): Untuk workload MoE, Megablox secara efisien menangani perkalian matriks yang dikelompokkan dengan menghitung representasi aktivasi yang tidak beraturan. Operasi ini secara efisien memetakan dimensi tidak beraturan, menghitung perkalian matriks antara grup baris tidak beraturan di LHS, dan matriks pakar yang sesuai, sehingga tidak perlu mengisi batch ke ukuran tetap.
- Penyesuaian empiris dengan
tune-jax: Librarytune-jaxmemiliki utilitas untuk melakukan penelusuran empiris untuk ukuran blok yang optimal. Ukuran kernel default sering kali tidak optimal; penyesuaian memungkinkan pemilihan ukuran petak VMEM yang kompatibel dengan hardware untuk memaksimalkan penggunaan hardware. - Estimasi logit maks: Kernel perhatian Tokamax Splash dapat dioptimalkan lebih lanjut dengan menetapkan nilai untuk
max_logit_const. Jika disetel, nilai ini akan menggantikan penghitungan pengurangan logit maks selama operasi softmax perhatian (softmax(Q * KT)), sehingga mengurangi beberapa overhead komputasi dan sinkronisasi. Di MaxText, fitur ini diterapkan oleh konfigurasiuse_max_logits_estimate, yang dapat disetel keNone(dinonaktifkan) atau nilai floating point. Pastikan rentang logit model spesifik Anda tetap kompatibel dengan estimasi untuk mencegah overflow numerik. Pengujian konvergensi direkomendasikan jika nilai ini ditetapkan.