Melakukan inferensi multihost menggunakan Pathways

Inferensi multihost adalah metode menjalankan inferensi model yang mendistribusikan model di beberapa host akselerator. Hal ini memungkinkan inferensi model besar yang tidak sesuai dengan satu host. Jalur dapat di-deploy untuk kasus penggunaan inferensi multihost batch dan real-time.

Sebelum memulai

Pastikan Anda memiliki:

Menjalankan Inferensi batch menggunakan JetStream

JetStream adalah mesin yang dioptimalkan untuk throughput dan memori untuk inferensi model bahasa besar (LLM) di perangkat XLA, terutama Tensor Processing Unit (TPU) yang ditulis dalam JAX.

Anda dapat menggunakan image Docker JetStream yang telah dibuat sebelumnya untuk menjalankan beban kerja inferensi batch, seperti yang ditunjukkan dalam YAML berikut. Container ini dibangun dari project JetStream OSS. Untuk mengetahui informasi selengkapnya tentang tanda MaxText-JetStream, lihat Tanda server MaxText JetStream. Contoh berikut menggunakan chip Trillium (v6e-16) untuk memuat checkpoint int8 Llama3.1-405b dan melakukan inferensi di atasnya. Contoh ini mengasumsikan bahwa Anda sudah memiliki cluster GKE dengan setidaknya satu kumpulan node v6e-16 di dalamnya.

Mulai server model dan Pathways

  1. Dapatkan kredensial ke cluster dan tambahkan ke konteks kubectl lokal Anda.
          gcloud container clusters get-credentials $CLUSTER \
          --zone=$ZONE \
          --project=$PROJECT \
          && kubectl config set-context --current --namespace=default
        
  2. Men-deploy LeaderWorkerSet (LWS) API.
          VERSION=v0.4.0
          kubectl apply --server-side -f "https://github.com/kubernetes-sigs/lws/releases/download/${VERSION}/manifests.yaml"
        
  3. Salin dan tempel YAML berikut ke dalam file bernama pathways-job.yaml: YAML ini telah dioptimalkan untuk bentuk irisan v6e-16. Untuk mengetahui informasi selengkapnya tentang cara mengonversi checkpoint Meta menjadi checkpoint yang kompatibel dengan JAX, ikuti panduan pembuatan checkpoint di Membuat checkpoint inferensi. Sebagai contoh, petunjuk untuk Llama3.1-405B disediakan di sini Konversi titik pemeriksaan untuk Llama3.1-405B.
        apiVersion: leaderworkerset.x-k8s.io/v1
        kind: LeaderWorkerSet
        metadata:
          name: jetstream-pathways
          annotations:
            leaderworkerset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
        spec:
          replicas: 1
          leaderWorkerTemplate:
            leaderTemplate:
              metadata:
                labels: 
                  app: jetstream-pathways
              spec:
                nodeSelector:
                  cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice
                  cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 4x4
                tolerations:
                - key: "google.com/tpu"
                  operator: "Exists"
                  effect: "NoSchedule"
                containers:
                - name: pathways-proxy
                  image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest
                  args:
                  - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677
                  - --server_port=38681
                  - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
                  imagePullPolicy: Always
                  ports:
                  - containerPort: 38681
                - name: pathways-rm
                  env:       
                  - name: HOST_ADDRESS
                    value: "$(LWS_LEADER_ADDRESS)"
                  - name: TPU_SKIP_MDS_QUERY
                    value: "true"
                  image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest
                  args:
                  - --server_port=38677
                  - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
                  - --node_type=resource_manager
                  - --instance_count=1
                  - --instance_type=tpuv6e:TPU_TOPOLOGY # Example: 4x4
                  - --temporary_flags_for_debugging=temporary_flag_for_debugging_worker_expected_tpu_chip_config=megachip_tccontrol
                  imagePullPolicy: Always
                  ports:
                  - containerPort: 38677
                - name: jax-tpu
                  image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-pathways:v0.2.0
                  # Optimized settings used to serve Llama3.1-405b.
                  args:
                  - MaxText/configs/v5e/inference/llama3_405b_v5e-64.yml
                  - model_name=llama3.1-405b
                  - load_parameters_path=GCS_CHECKPOINT_PATH
                  - max_prefill_predict_length=1024
                  - max_target_length=2048
                  - async_checkpointing=false
                  - steps=1
                  - ici_fsdp_parallelism=1
                  - ici_autoregressive_parallelism=2
                  - ici_tensor_parallelism=8
                  - scan_layers=false
                  - weight_dtype=bfloat16
                  - per_device_batch_size=10
                  - enable_single_controller=true
                  - quantization=int8
                  - quantize_kvcache=true
                  - checkpoint_is_quantized=true
                  - enable_model_warmup=true
                  imagePullPolicy: Always
                  ports: 
                  - containerPort: 9000
                  startupProbe:
                    httpGet:
                      path: /healthcheck
                      port: 8000
                      scheme: HTTP
                    periodSeconds: 1
                    initialDelaySeconds: 900
                    failureThreshold: 10000
                  livenessProbe:
                    httpGet:
                      path: /healthcheck
                      port: 8000
                      scheme: HTTP
                    periodSeconds: 60
                    failureThreshold: 10
                  readinessProbe:
                    httpGet:
                      path: /healthcheck
                      port: 8000
                      scheme: HTTP
                    periodSeconds: 60
                    failureThreshold: 10
                - name: jetstream-http
                  image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.3
                  imagePullPolicy: Always
                  ports:
                  - containerPort: 8000
            # The size variable defines the number of worker nodes to be created.
            # It must be equal to the number of hosts + 1 (for the leader node).
            size: 5
            workerTemplate:
              spec:
                nodeSelector:
                  cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice
                  cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 4x4
                tolerations:
                - key: "google.com/tpu"
                  operator: "Exists"
                  effect: "NoSchedule"
                containers:
                - name: worker
                  args:
                  - --server_port=38679
                  - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677
                  - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
                  image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest
                  imagePullPolicy: Always
                  ports:
                  - containerPort: 38679
                  resources:
                    limits:
                      google.com/tpu: "4"
        --- 
        apiVersion: v1
        kind: Service
        metadata:
          name: jetstream-svc
        spec:
          selector:
            app: jetstream-pathways
          ports:
          - protocol: TCP
            name: jetstream-http
            port: 8000
            targetPort: 8000
        
    Ganti kode berikut:
    • TPU_ACCELERATOR_TYPE: Jenis akselerator TPU. Contoh, tpu-v6e-slice.
    • TPU_TOPOLOGY: Topologi TPU. Contoh, 2x4.
    • GCS_CHECKPOINT_PATH: Jalur GCS ke titik pemeriksaan.
    Terapkan YAML ini. Tunggu hingga PathwaysJob dijadwalkan. Setelah dijadwalkan, server model mungkin memerlukan waktu beberapa saat untuk memulihkan titik pemeriksaan. Untuk model 405B, proses ini memerlukan waktu sekitar 7 menit.
  4. Lihat log Kubernetes untuk mengetahui apakah server model JetStream sudah siap: Workload diberi nama `jetstream-pathways` di YAML sebelumnya, dan `0` adalah node head.
          kubectl logs -f jetstream-pathways-0 -c jax-tpu
          
    Outputnya mirip dengan berikut yang menunjukkan bahwa server model JetStream siap melayani permintaan:
        2025-03-02 02:15:07,682 - JetstreamLogger - INFO - Initializing the driver with 1 prefill engines and 1 generate engines in interleaved mode
        2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up prefill thread 0.
        2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up transfer thread 0.
        2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up generate thread 0.
        2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up detokenize thread 0.
        2025-03-02 02:15:07,685 - JetstreamLogger - INFO - Driver initialized.
        ...
        ...
        ...
        INFO:     Started server process [7]
        INFO:     Waiting for application startup.
        INFO:     Application startup complete.
        INFO:     Uvicorn running on http://0.0.0.0:9999 (Press CTRL+C to quit)
        

Hubungkan ke server model

Anda dapat mengakses deployment JetStream Pathways menggunakan layanan ClusterIP GKE. Layanan ClusterIP hanya dapat dijangkau dari dalam cluster. Oleh karena itu, untuk mengakses layanan dari luar cluster, Anda harus membuat sesi penerusan port terlebih dahulu dengan menjalankan perintah berikut:

kubectl port-forward pod/${HEAD_POD} 8000:8000

Pastikan Anda dapat mengakses server HTTP JetStream dengan membuka terminal baru dan menjalankan perintah berikut:

curl --request POST \
--header "Content-type: application/json" \
-s \
localhost:8000/generate \
--data \
'{
    "prompt": "What are the top 5 programming languages",
    "max_tokens": 200
}'

Permintaan awal dapat memerlukan waktu beberapa detik untuk diselesaikan karena pemanasan model. Outputnya akan mirip dengan berikut ini:

{
    "response": " for web development?\nThe top 5 programming languages for web development are:\n1. **JavaScript**: JavaScript is the most popular language for web development, used by over 90% of websites for client-side scripting. It's also popular for server-side programming with technologies like Node.js.\n2. **HTML/CSS**: HTML (Hypertext Markup Language) and CSS (Cascading Style Sheets) are not programming languages, but are essential for building websites. HTML is used for structuring content, while CSS is used for styling and layout.\n3. **Python**: Python is a popular language for web development, especially with frameworks like Django and Flask. It's known for its simplicity, flexibility, and large community of developers.\n4. **Java**: Java is a popular language for building enterprise-level web applications, especially with frameworks like Spring and Hibernate. It's known for its platform independence, strong security features, and large community of developers.\n5. **PHP**: PHP is a mature language for web"
}

Inferensi yang diuraikan

Penayangan yang tidak digabungkan adalah teknik untuk menjalankan model bahasa besar (LLM) yang memisahkan tahap pengisian awal dan dekode ke dalam proses yang berbeda, yang berpotensi berjalan di mesin yang berbeda. Hal ini memungkinkan pemanfaatan resource yang lebih baik dan dapat meningkatkan performa dan efisiensi, terutama untuk model besar.

  • Prefill: tahap ini memproses perintah input dan menghasilkan representasi perantara (seperti cache key-value). Proses ini sering kali membutuhkan banyak komputasi.
  • Dekode: tahap ini menghasilkan token output, satu per satu, menggunakan representasi pra-pengisian. Biasanya terikat dengan bandwidth memori.

Dengan memisahkan tahap ini, penayangan yang diuraikan memungkinkan pra-pengisian dan decoding berjalan secara paralel, sehingga meningkatkan throughput dan latensi.

Untuk mengaktifkan penayangan yang dipisah-pisahkan, ubah YAML berikut untuk menggunakan dua slice v6e-8: satu untuk pengisian otomatis dan yang lainnya untuk pembuatan. Sebelum melanjutkan, pastikan cluster GKE Anda memiliki setidaknya dua node pool yang dikonfigurasi dengan topologi v6e-8 ini. Untuk performa yang optimal, flag XLA tertentu telah dikonfigurasi.

Buat checkpoint llama2-70b dengan mengikuti proses yang sama seperti pembuatan checkpoint llama3.1-405b, yang dijelaskan di bagian sebelumnya.

  1. Untuk meluncurkan server JetStream dalam mode yang tidak digabungkan menggunakan Pathways, salin dan tempel YAML berikut ke dalam file bernama pathways-job.yaml:
    apiVersion: leaderworkerset.x-k8s.io/v1
    kind: LeaderWorkerSet
    metadata:
      name: jetstream-pathways
      annotations:
        leaderworkerset.sigs.k8s.io/subgroup-exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      replicas: 1
      leaderWorkerTemplate:
        subGroupPolicy:
          subGroupSize: 2
        leaderTemplate:
          metadata:
            labels:
              app: jetstream-pathways
          spec:
            nodeSelector:
              cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice
              cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 2x4
            tolerations:
            - key: "google.com/tpu"
              operator: "Exists"
              effect: "NoSchedule"
            containers:
            - name: pathways-proxy
              image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest
              args:
              - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677
              - --server_port=38681
              - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
              imagePullPolicy: Always
              ports:
              - containerPort: 38681
            - name: pathways-rm
              env:       
              - name: HOST_ADDRESS
                value: "$(LWS_LEADER_ADDRESS)"
              - name: TPU_SKIP_MDS_QUERY
                value: "true"
              image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest
              args:
              - --server_port=38677
              - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
              - --node_type=resource_manager
              - --instance_count=2
              - --instance_type=tpuv6e:TPU_TOPOLOGY # Example: 2x4
              imagePullPolicy: Always
              ports:
              - containerPort: 38677
            - name: jax-tpu
              image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-pathways:v0.2.0
              # Optimized settings used to serve Llama2-70b.
              args:
              - MaxText/configs/base.yml
              - tokenizer_path=assets/tokenizer.llama2
              - load_parameters_path=GCS_CHECKPOINT_PATH
              - max_prefill_predict_length=1024
              - max_target_length=2048
              - model_name=llama2-70b
              - ici_fsdp_parallelism=1
              - ici_autoregressive_parallelism=1
              - ici_tensor_parallelism=-1
              - scan_layers=false
              - weight_dtype=bfloat16
              - per_device_batch_size=1
              - checkpoint_is_quantized=true 
              - quantization=int8
              - quantize_kvcache=true
              - compute_axis_order=0,2,1,3
              - ar_cache_axis_order=0,2,1,3
              - stack_prefill_result_cache=True
              # Specify disaggregated mode to run Jetstream
              - inference_server=ExperimentalMaxtextDisaggregatedServer_8
              - inference_benchmark_test=True
              - enable_model_warmup=True
              env:
              - name: LOG_LEVEL
                value: "INFO"
              imagePullPolicy: Always
              securityContext:
                capabilities:
                  add: ["SYS_PTRACE", "NET_ADMIN", "SYS_TIME"]
              ports: 
              - containerPort: 9000
              startupProbe:
                httpGet:
                  path: /healthcheck
                  port: 8000
                  scheme: HTTP
                periodSeconds: 1
                initialDelaySeconds: 240
                failureThreshold: 10000
              livenessProbe:
                httpGet:
                  path: /healthcheck
                  port: 8000
                  scheme: HTTP
                periodSeconds: 60
                failureThreshold: 100
              readinessProbe:
                httpGet:
                  path: /healthcheck
                  port: 8000
                  scheme: HTTP
                periodSeconds: 60
                failureThreshold: 100
            - name: jetstream-http
              image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.3
              imagePullPolicy: Always
              ports:
              - containerPort: 8000
        # The size variable defines the number of worker nodes to be created.
        # It must be equal to the number of hosts + 1 (for the leader node).
        size: 5
        workerTemplate:
          spec:
            nodeSelector:
              cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice
              cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 2x4
            containers:
            - name: worker
              args:
              - --server_port=38679
              - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677
              - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
              image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest
              imagePullPolicy: Always
              ports:
              - containerPort: 38679
              resources:
                limits:
                  google.com/tpu: "4"
    --- 
    apiVersion: v1
    kind: Service
    metadata:
      name: jetstream-svc
    spec:
      selector:
        app: jetstream-pathways
      ports:
      - protocol: TCP
        name: jetstream-http
        port: 8000
        targetPort: 8000
      
    Ganti kode berikut:
    • TPU_ACCELERATOR_TYPE: Jenis akselerator TPU. Contoh, tpu-v6e-slice.
    • TPU_TOPOLOGY: Topologi TPU. Contoh, 2x4.
    • GCS_CHECKPOINT_PATH: Jalur GCS ke titik pemeriksaan.
  2. Terapkan YAML ini, server model akan memerlukan waktu beberapa saat untuk memulihkan checkpoint. Untuk model 70B, proses ini dapat memakan waktu ~2 menit.
      kubectl apply -f pathways-job.yaml
          
  3. Lihat log Kubernetes untuk melihat apakah server model JetStream sudah siap:
        kubectl logs -f jetstream-pathways-0 -c jax-tpu
        
    Anda akan melihat output yang mirip dengan berikut yang menunjukkan bahwa server model JetStream siap melayani permintaan:
        2025-03-02 02:15:07,682 - JetstreamLogger - INFO - Initializing the driver with 1 prefill engines and 1 generate engines in interleaved mode
        2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up prefill thread 0.
        2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up transfer thread 0.
        2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up generate thread 0.
        2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up detokenize thread 0.
        2025-03-02 02:15:07,685 - JetstreamLogger - INFO - Driver initialized.
        ...
        ...
        ...
        INFO:     Started server process [7]
        INFO:     Waiting for application startup.
        INFO:     Application startup complete.
        INFO:     Uvicorn running on http://0.0.0.0:9999 (Press CTRL+C to quit)
      

Hubungkan ke server model

Anda dapat mengakses deployment JetStream Pathways melalui layanan ClusterIP GKE. Layanan ClusterIP hanya dapat dijangkau dari dalam cluster. Oleh karena itu, untuk mengakses layanan dari luar cluster, buat sesi penerusan port dengan menjalankan perintah berikut:

kubectl port-forward pod/${HEAD_POD} 8000:8000

Pastikan Anda dapat mengakses server HTTP JetStream dengan membuka terminal baru dan menjalankan perintah berikut:

curl --request POST \
--header "Content-type: application/json" \
-s \
localhost:8000/generate \
--data \
'{
    "prompt": "What are the top 5 programming languages",
    "max_tokens": 200
}'

Permintaan awal dapat memerlukan waktu beberapa detik untuk diselesaikan karena pemanasan model. Outputnya akan mirip dengan berikut ini:

{
    "response": " used in software development?\nThe top 5 programming languages used in software development are:\n\n1. Java: Java is a popular programming language used for developing enterprise-level applications, Android apps, and web applications. Its platform independence and ability to run on any device that has a Java Virtual Machine (JVM) installed make it a favorite among developers.\n2. Python: Python is a versatile language that is widely used in software development, data analysis, artificial intelligence, and machine learning. Its simplicity, readability, and ease of use make it a popular choice among developers.\n3. JavaScript: JavaScript is a widely used programming language for web development, allowing developers to create interactive client-side functionality for web applications. It is also used for server-side programming, desktop and mobile application development, and game development.\n4. C++: C++ is a high-performance programming language used for developing operating systems, games, and other high-performance applications."
}

Langkah berikutnya