Menjalankan workload batch dengan Pathways

Untuk tujuan dokumen ini, workload batch ditentukan sebagai workload JAX yang dieksekusi hingga selesai dan di-deploy dalam cluster GKE yang sama dengan cluster Pathways, khususnya bersama komponen pengontrol Pathways (server proxy IFRT dan pengelola resource Pathways). Penyelesaian workload JAX akan menghentikan komponen cluster Pathways. Panduan ini menggunakan workload pelatihan JAX untuk mendemonstrasikan hal ini.

Sebelum memulai

Pastikan Anda memiliki:

Membangun image pelatihan menggunakan Maxtext

MaxText adalah project model bahasa besar (LLM) open source yang dikembangkan oleh Google. Model ini ditulis dalam JAX dan dirancang agar berperforma tinggi dan skalabel, serta berjalan secara efisien di TPU dan GPU Google Cloud.

Untuk membangun image Docker MaxText menggunakan JAX stabil versi terbaru dari repositori GitHub OSS, jalankan perintah berikut:

git clone https://github.com/AI-Hypercomputer/maxtext
cd maxtext/dependencies/scripts
gcloud config set project PROJECT
bash ./docker_build_dependency_image.sh MODE=stable
gcloud auth configure-docker
bash ./docker_upload_runner.sh CLOUD_IMAGE_NAME=USER_runner # This script needs bash version >= 4.2 to execute.

Perintah ini mengirim image Kubernetes MaxText ke gcr.io/$PROJECT/${USER}_runner. Anda dapat menggunakan image Docker ini untuk menjalankan pelatihan di TPU menggunakan backend Pathways.

Menjalankan workload batch menggunakan PathwaysJob API

Manifes berikut men-deploy komponen Pathways dan menjalankan beban kerja MaxText menggunakan PathwaysJob API. Workload dienkapsulasi dalam container main dan menjalankan train.py.

Salin YAML berikut ke dalam file bernama pathways-job-batch-training.yaml dan perbarui nilai yang dapat diedit.

apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USER
spec:
  maxRestarts: MAX_RESTARTS
  workers:
  - type: TPU_MACHINE_TYPE
    topology: TOPOLOGY
    numSlices: WORKLOAD_NODEPOOL_COUNT
  pathwaysDir: gs://BUCKET_NAME
  controller:
    deploymentMode: default
    template:
      spec:
        containers:
        - name: main
          image: gcr.io/PROJECT/USER_runner
          command:
          - bash
          - -c
          - |
            python3 -m MaxText.train /deps/src/MaxText/configs/base.yml \
            base_output_directory=gs://BUCKET_NAME \
            run_name=RUN_NAME \
            per_device_batch_size=1 \
            enable_checkpointing=false \
            remat_policy=full \
            global_parameter_scale=1 \
            steps=20 \
            max_target_length=2048 \
            use_iota_embed=true \
            reuse_example_batch=1 \
            dataset_type=synthetic \
            attention=flash \
            gcs_metrics=True \
            enable_single_controller=True

Ganti kode berikut:

  • USER : ID Google Cloud pengguna Anda
  • MAX_RESTARTS : jumlah maksimum upaya memulai ulang Tugas
  • TPU_MACHINE_TYPE : jenis mesin TPU
  • TOPOLOGY : topologi TPU v4 atau yang lebih baru. Untuk mengetahui informasi selengkapnya tentang versi TPU dan topologi yang didukung, lihat Versi TPU
  • WORKLOAD_NODEPOOL_COUNT : jumlah node pool yang digunakan oleh beban kerja Pathways
  • BUCKET_NAME : bucket Cloud Storage untuk menyimpan file sementara
  • PROJECT : Google Cloud project ID Anda
  • RUN_NAME : nama yang ditetapkan pengguna untuk mengidentifikasi eksekusi alur kerja

Anda dapat men-deploy YAML PathwaysJob sebagai berikut:

kubectl apply -f pathways-job-batch-training.yaml

Untuk melihat instance PathwaysJob yang dibuat oleh perintah sebelumnya, gunakan:

kubectl get pathwaysjob

Output-nya akan terlihat seperti ini:

NAME             AGE
pathways-trial   9s

Untuk mengubah atribut instance PathwaysJob, hapus instance PathwaysJob, ubah YAML, lalu terapkan untuk membuat instance PathwaysJob baru.

Anda dapat mengikuti progres workload dengan membuka Logs Explorer untuk container JAX dengan memilih main di bagian filter Nama Container.

Anda akan melihat log seperti berikut yang menunjukkan bahwa pelatihan sedang berlangsung. Workload akan selesai setelah 30 langkah.

completed step: 1, seconds: 0.484, TFLOP/s/device: 87.349, Tokens/s/device: 2117.382, total_weights: 2945, loss: 10.888
completed step: 2, seconds: 0.407, TFLOP/s/device: 103.699, Tokens/s/device: 2513.735, total_weights: 3253, loss: 9.697
completed step: 3, seconds: 0.248, TFLOP/s/device: 170.300, Tokens/s/device: 4128.167, total_weights: 3154, loss: 9.641
completed step: 4, seconds: 0.216, TFLOP/s/device: 195.122, Tokens/s/device: 4729.880, total_weights: 3119, loss: 9.547
completed step: 5, seconds: 0.272, TFLOP/s/device: 155.298, Tokens/s/device: 3764.512, total_weights: 2837, loss: 10.179
completed step: 6, seconds: 0.472, TFLOP/s/device: 89.489, Tokens/s/device: 2169.266, total_weights: 3069, loss: 9.776

Untuk menghapus instance PathwaysJob, Anda dapat menggunakan perintah berikut:

kubectl delete -f pathways-job-batch-training.yaml

Menjalankan workload batch menggunakan XPK

Sekarang Anda dapat mengirimkan image Docker Maxtext bawaan menggunakan XPK dengan perintah yang sama dengan yang Anda gunakan sebelumnya.

xpk workload create-pathways \
--workload=WORKLOAD \
--cluster=CLUSTER \
--num-slices=WORKLOAD_NODEPOOL_COUNT \
--tpu-type=TPU_TYPE \
--project=PROJECT \
--zone=ZONE \
--docker-image='gcr.io/PROJECT/USER_runner' \
--command="python3 -m MaxText.train /deps/src/MaxText/configs/base.yml base_output_directory=gs://BUCKET_NAME per_device_batch_size=1 enable_checkpointing=false remat_policy=full global_parameter_scale=1 steps=20 max_target_length=2048 use_iota_embed=true reuse_example_batch=1 dataset_type=synthetic attention=flash gcs_metrics=True enable_single_controller=True run_name=RUN_NAME-pathways-job"

Ganti kode berikut:

  • WORKLOAD: nama unik untuk mengidentifikasi workload Anda
  • CLUSTER: nama cluster GKE Anda
  • WORKLOAD_NODEPOOL_COUNT : jumlah maksimum tugas dapat dimulai ulang
  • TPU_TYPE: jenis TPU menentukan versi dan ukuran Cloud TPU yang ingin Anda buat. Untuk mengetahui informasi selengkapnya tentang jenis TPU yang didukung untuk setiap versi TPU, lihat Versi TPU
  • PROJECT : Project ID Google Cloud Anda
  • ZONE: zona tempat Anda berencana menjalankan workload
  • USER : ID Google Cloud pengguna Anda
  • RUN_NAME : nama yang ditetapkan pengguna untuk mengidentifikasi eksekusi alur kerja

Anda akan melihat output seperti berikut:

[XPK] Follow your Pathways workload and other resources here : https://console.cloud.google.com/logs/query;query=resource.type%3D"k8s_container"%0Aresource.labels.project_id%3D"<project-name>"%0Aresource.labels.location%3D"<your-zone>"%0Aresource.labels.cluster_name%3D"<your-cluster-name>"%0Aresource.labels.pod_name:"<your-pod-name>"%0Aseverity>%3DDEFAULT

Gunakan link di output dari perintah XPK sebelumnya untuk memantau progres beban kerja Anda. Anda dapat memfilter log untuk penampung JAX dengan memilih jax-tpu di bagian filter Nama Penampung.

completed step: 1, seconds: 0.484, TFLOP/s/device: 87.349, Tokens/s/device: 2117.382, total_weights: 2945, loss: 10.888
completed step: 2, seconds: 0.407, TFLOP/s/device: 103.699, Tokens/s/device: 2513.735, total_weights: 3253, loss: 9.697
completed step: 3, seconds: 0.248, TFLOP/s/device: 170.300, Tokens/s/device: 4128.167, total_weights: 3154, loss: 9.641
completed step: 4, seconds: 0.216, TFLOP/s/device: 195.122, Tokens/s/device: 4729.880, total_weights: 3119, loss: 9.547
completed step: 5, seconds: 0.272, TFLOP/s/device: 155.298, Tokens/s/device: 3764.512, total_weights: 2837, loss: 10.179
completed step: 6, seconds: 0.472, TFLOP/s/device: 89.489, Tokens/s/device: 2169.266, total_weights: 3069, loss: 9.776

Workload selesai setelah jumlah langkah yang ditentukan. Jika Anda ingin menghentikannya sebelum waktunya, gunakan perintah berikut:

xpk workload delete --workload=WORKLOAD --cluster=CLUSTER --project=PROJECT --zone=ZONE

Langkah berikutnya