Pelatihan yang tangguh dengan Pathways

Jalur memberikan manfaat ketahanan dengan cara berikut:

  • Tunda-Lanjutkan: toleransi dalam menghadapi gangguan terencana seperti pemberitahuan pendahuluan tanpa perlu pengguna menulis kode penanganan pendahuluan kustom.
  • Pelatihan Elastis: toleransi dalam menghadapi kegagalan hardware yang tidak direncanakan tanpa menyebabkan klien mengalami error, tetapi mengharuskan pengguna menulis kode pemulihan khusus model.

    Sebelum memulai

    Pastikan Anda memiliki:

    Menangguhkan-melanjutkan

    Biasanya, GKE mengirimkan pemberitahuan preemption ke pod akselerator, sebelum pod tersebut di-preempt. Toleransi pengambilalihan jalur diaktifkan secara default di semua deployment cloud dan tugas akselerator Pathways memantau pemberitahuan ini.

    Saat pemberitahuan penghentian sementara tiba, Pathways terlebih dahulu menentukan apakah workload saat ini dapat dipulihkan - apakah Pathways dapat menyimpan dan memulihkan workload secara transparan. Jika demikian, GKE akan mencoba menangguhkan workload ML Anda secara transparan dengan menulis status saat ini ke penyimpanan persisten seperti Cloud Storage sebelum GKE menghapus tugas akselerator Anda. Saat GKE menjadwalkan ulang tugas Anda nanti, Pathways akan melanjutkan workload ML Anda dengan membaca kembali statusnya yang dipertahankan.

    Jika beban kerja tidak dapat dipulihkan, Pathways akan menghentikan tugas akselerator dan meneruskan kegagalan ke tugas Anda jika Pelatihan elastis dikonfigurasi. Jika pelatihan Elastic tidak dikonfigurasi, GKE akan memulai ulang seluruh workload berdasarkan kebijakan mulai ulang JobSet.

    Beban kerja ML umum yang ditentukan menggunakan JAX mengandalkan komponen XLA Pathways stateless yang dapat dipulihkan menggunakan snapshot memori bandwidth tinggi (HBM). Beberapa beban kerja ML, seperti yang ditentukan menggunakan JAX colocated python API, bergantung pada komponen Pathways stateful; komponen ini tidak dapat dipulihkan.

    Pelatihan elastis

    Pelatihan elastis memungkinkan tugas pelatihan Anda berlanjut meskipun terjadi kegagalan hardware. Hal ini dicapai melalui kombinasi kemampuan sistem Pathways dan logika pemulihan model yang ditentukan pengguna:

    • Deteksi kegagalan: Jika terjadi kegagalan hardware (misalnya, pekerja TPU mengalami error), sistem Pathways akan mendeteksinya dan memberi tahu tugas pelatihan pengguna melalui pengecualian saat data yang berada di hardware tersebut diakses pada waktu berikutnya. Notifikasi ini tidak menyebabkan beban kerja Anda error; notifikasi ini memungkinkan kode Anda menangani notifikasi dan mengonfigurasi ulang resource Anda untuk melanjutkan pemrosesan atau keluar dengan baik.
    • Handler elastisitas yang ditentukan pengguna: Kode model pengguna harus dapat menangani pengecualian ini. Inilah yang membuatnya menjadi "pemulihan khusus model".
      • Membuat snapshot: Pendekatan yang paling umum adalah menyimpan snapshot status model Anda secara berkala. Jika terjadi kegagalan, Anda dapat memuat dari snapshot terbaru untuk melanjutkan pelatihan.
      • Rekonfigurasi: Anda mungkin perlu mengonfigurasi ulang tugas pelatihan untuk menyesuaikan jumlah slice yang tersedia. Misalnya, jika satu slice berhenti berfungsi, Anda dapat mengurangi jumlah slice aktif sebanyak satu hingga penggantinya tersedia. Untuk mengetahui informasi selengkapnya, lihat Elastic Handler.
      • Pembaruan grafik Data/Komputasi: Kode Anda harus menangani setiap perubahan jumlah perangkat yang tersedia untuk komputasi Anda dengan membuat ulang grafik komputasi sesuai kebutuhan. Hal ini mungkin melibatkan pemartisian ulang data atau mengompilasi ulang model Anda.
    • Peran Pathways dalam pemulihan: Pathways menyediakan primitif untuk mendukung rekonfigurasi yang ditentukan pengguna:
      • Penggantian slice: Jika slice yang gagal diganti, klien dapat diberi tahu setelah slice baru tersedia. Kode Anda kemudian dapat dikonfigurasi ulang untuk menggunakan slice baru ini.
      • Pemulihan transparan: Pathways menangani detail tingkat bawah dari pemulihan, seperti membangun kembali koneksi ke bagian cluster yang berfungsi dengan baik.
    • Utilitas di pathwaysutils: Kumpulan utilitas Pathways yang ditentukan di pathways-utils.

    Menerapkan pengendali elastis

    Sebagian besar kode yang harus Anda tulis akan berada di handler elastis yang ditentukan pengguna. Handler ini bereaksi terhadap peristiwa elastis (seperti slice TPU yang menjadi tidak tersedia) dengan membuat ulang mesh dan menginisialisasi ulang loop pelatihan.

    Setiap beban kerja bersifat unik. Kompleksitas handler elastis dapat diskalakan dengan kompleksitas beban kerja. Input dan output handler harus berupa argumen dan nilai yang ditampilkan minimum yang diperlukan untuk menginisialisasi ulang loop pelatihan.

    def elastic_handler(elastic_utils, *args, **kwargs):
      mesh = initialize_mesh(**kwargs["mesh_kwargs"])
      initial_state, initial_step, jitted_train_step, other_variables =
          initialize_training_loop(mesh, **kwargs["initialize_training_loop_kwargs"])
    
      step, snapshot = elastic_utils.get_next_snapshot()
      state = initial_state.replace(**snapshot)
    
      return state, step, mesh, jitted_train_step, other_variables
    

    Memperbarui loop pelatihan

    Anda perlu melakukan perubahan berikut pada loop pelatihan:

    1. Membuat pengelola elastis
    2. Gabungkan loop pelatihan Anda di dalam blok try-except yang menangani jax.errors.JaxRuntimeError
    3. Dalam handler jax.errors.JaxRuntimeError, panggil maybe_reshard_down. Pengelola elastis akan melakukan pengecilan ulang jika error terkait dengan peristiwa elastis atau memunculkannya kembali.
    4. Panggil maybe_snapshot dan maybe_reshard_up di akhir loop pelatihan
    import pathwaysutils
    from pathwaysutils.elastic import manager
    
    pathwaysutils.initialize()
    
    def initialize_mesh(**kwargs):
      ...
    
    
    def initialize_training_loop(**kwargs):
      ...
    
    
    def train_loop(
        final_step,
        elastic_manager,
        mesh_kwargs,
        initialize_training_loop_kwargs,
    ):
      mesh = initialize_mesh(**mesh_kwargs)
      initial_state, initial_step, jitted_train_step, other_variables =
          initialize_training_loop(mesh, **initialize_training_loop_kwargs)
    
      step = initial_step
      while step < final_step:
        try:
          state = jitted_train_step(state)
    
          elastic_manager.maybe_snapshot(step=step, snapshot=state)
          handler_returns = elastic_manager.maybe_reshard_up(
              step=step,
              snapshot=state,
              elastic_handler=elastic_handler,
              handler_args=(),
              handler_kwargs=dict(
                  mesh_kwargs=mesh_kwargs,
                  initialize_training_loop_kwargs=initialize_training_loop_kwargs,
              ),
          )
          if handler_returns:
            state, step, mesh, jitted_train_step, other_variables = handler_returns
          step += 1
        except jax.errors.JaxRuntimeError as error:
          handler_returns = elastic_manager.maybe_reshard_down(
              error=error,
              elastic_handler=elastic_handler,
              handler_args=(),
              handler_kwargs=dict(
                  mesh_kwargs=mesh_kwargs,
                  initialize_training_loop_kwargs=initialize_training_loop_kwargs,
              ),
          )
          if handler_returns:
            state, step, mesh, jitted_train_step, other_variables = handler_returns
    
      return state
    
    
    def main():
      elastic_manager = manager.Manager(
          devices=jax.devices(),
          snapshot_period=10,
          snapshot_buffer_size=1,
          reshard_check_period=5,
          max_elastic_down_event_count=10,
          max_reshard_retry_count=3,
      )
    
      train_loop(100, elastic_manager, {}, {})
    

    Mengonfigurasi pengelola elastis

    Pengelola elastis dapat dikonfigurasi dengan beberapa cara yang berbeda. Frekuensi pengambilan snapshot ditentukan oleh periode snapshot. Periode pengambilan snapshot memengaruhi jumlah rata-rata langkah yang hilang karena peristiwa elastis. Periode pemeriksaan resharing menentukan seberapa sering loop pelatihan Anda akan melakukan polling untuk ketersediaan slice. max_elastic_down_event_count memungkinkan Anda menetapkan jumlah peristiwa elastis yang disebabkan oleh kehilangan slice yang akan didukung oleh loop pelatihan Anda. max_reshard_retry_count menentukan frekuensi pengelola elastis harus mencoba ulang perubahan partisi. Pengelola adalah objek singleton dan hanya boleh dibuat satu kali.

    Snapshot

    Berdasarkan konfigurasi pengelola elastis, fungsi dapat mengambil snapshot data ke dalam memori host yang akan tersedia untuk digunakan oleh handler elastis Anda selama peristiwa elastis.

    Mengurangi sharding

    Setelah menangkap jax.errors.JaxRuntimeError, Pathways akan memeriksa apakah error disebabkan oleh peristiwa elastis karena slice yang hilang. Jika ya, fungsi ini akan memanggil handler elastis dalam loop hingga berhasil atau mencapai upaya percobaan ulang maksimum. Jika error tidak disebabkan oleh peristiwa elastis, error akan muncul lagi. Nilai hasil dari elastic handler diteruskan ke pemanggil.

    Meningkatkan sharding

    Berdasarkan konfigurasi pengelola elastis dan jika ada slice yang tidak tersedia, Pathways akan memeriksa apakah ada slice tambahan yang tersedia. Jika demikian, snapshot akan segera disimpan (jika snapshot yang sudah ada sebelumnya untuk langkah saat ini belum diambil) dan memanggil handler elastis dalam loop hingga berhasil atau jumlah maksimum upaya percobaan ulang tercapai. Jika terjadi penyesuaian ukuran ulang, nilai yang ditampilkan dari pengendali elastis akan diteruskan ke pemanggil. Jika tidak, None akan ditampilkan.

    Hot-swap

    Hot-Swap mengacu pada fitur GKE JobSet API di mana tugas berprioritas lebih tinggi dapat dengan cepat mengambil alih resource dari tugas berprioritas lebih rendah, sehingga meminimalkan waktu henti dan memastikan pemulihan yang lebih cepat.

    Saat JobSet dibuat, GKE menjadwalkan workload di beberapa slice, seperti yang ditentukan dalam konfigurasi JobSet. Jika terjadi kegagalan hardware pada satu atau beberapa slice, Pod yang terpengaruh akan ditandai sebagai gagal. Saat menjadwalkan ulang Jobset ini, jika Anda telah memilih untuk menyimpan slice cadangan di cluster GKE yang dapat digunakan untuk Tugas dengan prioritas lebih rendah, sistem JobSet akan memetakan ulang workload slice yang gagal dari tugas dengan prioritas lebih tinggi ke slice cadangan yang digunakan oleh tugas dengan prioritas lebih rendah dalam cluster GKE yang sama. Pemetaan ulang ini biasanya memerlukan waktu kurang dari satu menit.

    Setelah JobSet dimulai ulang, penggantian cepat dapat terjadi dalam situasi berikut:

    1. Mode Default: Jika slice TPU cadangan yang tidak digunakan tersedia dalam cluster yang sama, penjadwal Kubernetes akan memprioritaskan penjadwalan ulang tugas yang dimulai ulang ke slice ini, daripada menunggu slice yang gagal diperbaiki. Hal ini memberikan pemulihan yang lebih cepat.
    2. Workload Heterogen: Di cluster yang menjalankan beberapa workload dengan PriorityClass Kubernetes yang dikonfigurasi, JobSet yang dimulai ulang dapat memicu penggantian langsung. Jika afinitas tugas yang dimulai ulang cocok dengan resource tugas berprioritas lebih rendah, Kubernetes akan mendahului tugas berprioritas lebih rendah, sehingga tugas berprioritas lebih tinggi dapat segera dimulai. Misalnya, Anda dapat mengonfigurasi pod pekerja Pathways dengan prioritas yang berbeda menggunakan PriorityClass.

    Untuk menggunakan prioritas di cluster Anda, tentukan class prioritas, misalnya:

    kind: PriorityClass
    metadata:
      name: high-prior-job
    value: 2000
    globalDefault: false
    description: "This priority class should be used for high priority job."
    

    Terapkan YAML ini ke cluster GKE Anda:

    kubectl apply -f high-prior-job.yaml
    

    Selanjutnya, lampirkan class prioritas baru ke tugas pekerja Pathways Anda dengan menambahkan teks berikut ke podspec Pod pathways-worker Anda.

    priorityClassName: high-prior-job
    

    Langkah berikutnya