Inferensi multihost adalah metode menjalankan inferensi model yang mendistribusikan model ke beberapa host akselerator. Hal ini memungkinkan inferensi model besar yang tidak sesuai dengan satu host. Pathways dapat di-deploy untuk kasus penggunaan inferensi multihost batch dan real time.
Sebelum memulai
Pastikan Anda memiliki:
- Membuat cluster GKE yang menggunakan chip Trillium (v6e-16).
- Menginstal alat Kubernetes
- Mengaktifkan Google Kubernetes Engine API
Menjalankan inferensi Batch menggunakan JetStream
JetStream adalah mesin yang dioptimalkan untuk throughput dan memori untuk inferensi model bahasa besar (LLM) pada perangkat XLA, terutama Tensor Processing Unit (TPU) yang ditulis dalam JAX.
Anda dapat menggunakan image Docker JetStream yang telah dibuat sebelumnya untuk menjalankan workload inferensi batch, seperti yang ditunjukkan dalam YAML berikut. Container ini dibuat dari project
OSS JetStream.
Untuk mengetahui informasi selengkapnya tentang flag MaxText-JetStream, lihat
Flag server JetStream MaxText.
Contoh berikut menggunakan chip Trillium (v6e-16) untuk memuat checkpoint int8 Llama3.1-405b dan melakukan inferensi di atasnya. Contoh ini mengasumsikan bahwa Anda sudah memiliki cluster GKE dengan setidaknya satu nodepool v6e-16 di dalamnya.
Memulai server model dan Pathways
- Dapatkan kredensial ke cluster dan tambahkan ke konteks kubectl lokal Anda.
gcloud container clusters get-credentials $CLUSTER \ --zone=$ZONE \ --project=$PROJECT \ && kubectl config set-context --current --namespace=default
- Deploy LeaderWorkerSet (LWS) API.
VERSION=v0.4.0 kubectl apply --server-side -f "https://github.com/kubernetes-sigs/lws/releases/download/${VERSION}/manifests.yaml"
- Salin dan tempel YAML berikut ke dalam file bernama
pathways-job.yaml: YAML ini telah dioptimalkan untuk bentuk slicev6e-16. Untuk mengetahui informasi selengkapnya tentang cara mengonversi checkpoint Meta menjadi checkpoint yang kompatibel dengan JAX, ikuti panduan pembuatan checkpoint di Membuat checkpoint inferensi. Sebagai contoh, petunjuk untuk Llama3.1-405B disediakan di sini Konversi checkpoint untuk Llama3.1-405B. Ganti hal berikut:apiVersion: leaderworkerset.x-k8s.io/v1 kind: LeaderWorkerSet metadata: name: jetstream-pathways annotations: leaderworkerset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool spec: replicas: 1 leaderWorkerTemplate: leaderTemplate: metadata: labels: app: jetstream-pathways spec: nodeSelector: cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 4x4 tolerations: - key: "google.com/tpu" operator: "Exists" effect: "NoSchedule" containers: - name: pathways-proxy image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest args: - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677 - --server_port=38681 - --gcs_scratch_location=gs://cloud-pathways-staging/tmp imagePullPolicy: Always ports: - containerPort: 38681 - name: pathways-rm env: - name: HOST_ADDRESS value: "$(LWS_LEADER_ADDRESS)" - name: TPU_SKIP_MDS_QUERY value: "true" image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest args: - --server_port=38677 - --gcs_scratch_location=gs://cloud-pathways-staging/tmp - --node_type=resource_manager - --instance_count=1 - --instance_type=tpuv6e:TPU_TOPOLOGY # Example: 4x4 - --temporary_flags_for_debugging=temporary_flag_for_debugging_worker_expected_tpu_chip_config=megachip_tccontrol imagePullPolicy: Always ports: - containerPort: 38677 - name: jax-tpu image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-pathways:v0.2.0 # Optimized settings used to serve Llama3.1-405b. args: - MaxText/configs/v5e/inference/llama3_405b_v5e-64.yml - model_name=llama3.1-405b - load_parameters_path=CHECKPOINT_PATH - max_prefill_predict_length=1024 - max_target_length=2048 - async_checkpointing=false - steps=1 - ici_fsdp_parallelism=1 - ici_autoregressive_parallelism=2 - ici_tensor_parallelism=8 - scan_layers=false - weight_dtype=bfloat16 - per_device_batch_size=10 - enable_single_controller=true - quantization=int8 - quantize_kvcache=true - checkpoint_is_quantized=true - enable_model_warmup=true imagePullPolicy: Always ports: - containerPort: 9000 startupProbe: httpGet: path: /healthcheck port: 8000 scheme: HTTP periodSeconds: 1 initialDelaySeconds: 900 failureThreshold: 10000 livenessProbe: httpGet: path: /healthcheck port: 8000 scheme: HTTP periodSeconds: 60 failureThreshold: 10 readinessProbe: httpGet: path: /healthcheck port: 8000 scheme: HTTP periodSeconds: 60 failureThreshold: 10 - name: jetstream-http image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.3 imagePullPolicy: Always ports: - containerPort: 8000 # The size variable defines the number of worker nodes to be created. # It must be equal to the number of hosts + 1 (for the leader node). size: 5 workerTemplate: spec: nodeSelector: cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 4x4 tolerations: - key: "google.com/tpu" operator: "Exists" effect: "NoSchedule" containers: - name: worker args: - --server_port=38679 - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677 - --gcs_scratch_location=gs://cloud-pathways-staging/tmp image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest imagePullPolicy: Always ports: - containerPort: 38679 resources: limits: google.com/tpu: "4" --- apiVersion: v1 kind: Service metadata: name: jetstream-svc spec: selector: app: jetstream-pathways ports: - protocol: TCP name: jetstream-http port: 8000 targetPort: 8000
TPU_ACCELERATOR_TYPE: Jenis akselerator TPU. Misalnya,tpu-v6e-slice.TPU_TOPOLOGY: Topologi TPU. Misalnya,2x4.CHECKPOINT_PATH: Jalur Cloud Storage ke checkpoint.
- Lihat log Kubernetes untuk melihat apakah server model JetStream sudah siap:
Workload diberi nama `jetstream-pathways` di YAML sebelumnya, dan `0`
adalah node utama.
Outputnya mirip dengan berikut ini yang menunjukkan bahwa server model JetStream siap untuk menayangkan permintaan:kubectl logs -f jetstream-pathways-0 -c jax-tpu
2025-03-02 02:15:07,682 - JetstreamLogger - INFO - Initializing the driver with 1 prefill engines and 1 generate engines in interleaved mode 2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up prefill thread 0. 2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up transfer thread 0. 2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up generate thread 0. 2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up detokenize thread 0. 2025-03-02 02:15:07,685 - JetstreamLogger - INFO - Driver initialized. ... ... ... INFO: Started server process [7] INFO: Waiting for application startup. INFO: Application startup complete. INFO: Uvicorn running on http://0.0.0.0:9999 (Press CTRL+C to quit)
Menghubungkan ke server model
Anda dapat mengakses deployment JetStream Pathways menggunakan layanan ClusterIP GKE. Layanan ClusterIP hanya dapat dijangkau dari dalam cluster. Oleh karena itu, untuk mengakses layanan dari luar cluster, Anda harus terlebih dahulu membuat sesi penerusan port dengan menjalankan perintah berikut:
kubectl port-forward pod/${HEAD_POD} 8000:8000
Pastikan Anda dapat mengakses server HTTP JetStream dengan membuka terminal baru dan menjalankan perintah berikut:
curl --request POST \
--header "Content-type: application/json" \
-s \
localhost:8000/generate \
--data \
'{
"prompt": "What are the top 5 programming languages",
"max_tokens": 200
}'
Permintaan awal dapat memerlukan waktu beberapa detik untuk diselesaikan karena pemanasan model. Outputnya akan mirip dengan berikut ini:
{
"response": " for web development?\nThe top 5 programming languages for web development are:\n1. **JavaScript**: JavaScript is the most popular language for web development, used by over 90% of websites for client-side scripting. It's also popular for server-side programming with technologies like Node.js.\n2. **HTML/CSS**: HTML (Hypertext Markup Language) and CSS (Cascading Style Sheets) are not programming languages, but are essential for building websites. HTML is used for structuring content, while CSS is used for styling and layout.\n3. **Python**: Python is a popular language for web development, especially with frameworks like Django and Flask. It's known for its simplicity, flexibility, and large community of developers.\n4. **Java**: Java is a popular language for building enterprise-level web applications, especially with frameworks like Spring and Hibernate. It's known for its platform independence, strong security features, and large community of developers.\n5. **PHP**: PHP is a mature language for web"
}
Inferensi terpisah
Inferensi terpisah adalah teknik untuk menjalankan model bahasa besar (LLM) yang memisahkan tahap prefill dan decode ke dalam proses yang berbeda, yang berpotensi berada di mesin yang berbeda. Hal ini memungkinkan penggunaan resource yang lebih baik dan dapat meningkatkan performa dan efisiensi, terutama untuk model besar.
- Prefill: tahap ini memproses perintah input dan menghasilkan representasi perantara (seperti cache nilai kunci). Biasanya memerlukan banyak komputasi.
- Decode: tahap ini menghasilkan token output, satu per satu, menggunakan representasi prefill. Biasanya terikat pada bandwidth memori.
Dengan memisahkan tahap ini, inferensi terpisah memungkinkan prefill dan decode berjalan secara paralel, sehingga meningkatkan throughput dan latensi.
Untuk mengaktifkan inferensi terpisah, ubah YAML berikut untuk menggunakan dua slice v6e-8: satu untuk prefill dan yang lainnya untuk pembuatan. Sebelum melanjutkan, pastikan cluster GKE Anda memiliki setidaknya dua node pool yang dikonfigurasi dengan topologi v6e-8 ini.
Untuk performa yang optimal, flag XLA tertentu telah dikonfigurasi.
Buat checkpoint llama2-70b dengan mengikuti proses yang sama seperti pembuatan checkpoint llama3.1-405b, yang dijelaskan di bagian sebelumnya.
- Untuk meluncurkan server JetStream dalam mode terpisah menggunakan Pathways, salin
dan tempel YAML berikut ke dalam file bernama
pathways-job.yaml: Ganti hal berikut:apiVersion: leaderworkerset.x-k8s.io/v1 kind: LeaderWorkerSet metadata: name: jetstream-pathways annotations: leaderworkerset.sigs.k8s.io/subgroup-exclusive-topology: cloud.google.com/gke-nodepool spec: replicas: 1 leaderWorkerTemplate: subGroupPolicy: subGroupSize: 2 leaderTemplate: metadata: labels: app: jetstream-pathways spec: nodeSelector: cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 2x4 tolerations: - key: "google.com/tpu" operator: "Exists" effect: "NoSchedule" containers: - name: pathways-proxy image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest args: - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677 - --server_port=38681 - --gcs_scratch_location=gs://cloud-pathways-staging/tmp imagePullPolicy: Always ports: - containerPort: 38681 - name: pathways-rm env: - name: HOST_ADDRESS value: "$(LWS_LEADER_ADDRESS)" - name: TPU_SKIP_MDS_QUERY value: "true" image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest args: - --server_port=38677 - --gcs_scratch_location=gs://cloud-pathways-staging/tmp - --node_type=resource_manager - --instance_count=2 - --instance_type=tpuv6e:TPU_TOPOLOGY # Example: 2x4 imagePullPolicy: Always ports: - containerPort: 38677 - name: jax-tpu image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-pathways:v0.2.0 # Optimized settings used to serve Llama2-70b. args: - MaxText/configs/base.yml - tokenizer_path=assets/tokenizer.llama2 - load_parameters_path=CHECKPOINT_PATH - max_prefill_predict_length=1024 - max_target_length=2048 - model_name=llama2-70b - ici_fsdp_parallelism=1 - ici_autoregressive_parallelism=1 - ici_tensor_parallelism=-1 - scan_layers=false - weight_dtype=bfloat16 - per_device_batch_size=1 - checkpoint_is_quantized=true - quantization=int8 - quantize_kvcache=true - compute_axis_order=0,2,1,3 - ar_cache_axis_order=0,2,1,3 - stack_prefill_result_cache=True # Specify disaggregated mode to run Jetstream - inference_server=ExperimentalMaxtextDisaggregatedServer_8 - inference_benchmark_test=True - enable_model_warmup=True env: - name: LOG_LEVEL value: "INFO" imagePullPolicy: Always securityContext: capabilities: add: ["SYS_PTRACE", "NET_ADMIN", "SYS_TIME"] ports: - containerPort: 9000 startupProbe: httpGet: path: /healthcheck port: 8000 scheme: HTTP periodSeconds: 1 initialDelaySeconds: 240 failureThreshold: 10000 livenessProbe: httpGet: path: /healthcheck port: 8000 scheme: HTTP periodSeconds: 60 failureThreshold: 100 readinessProbe: httpGet: path: /healthcheck port: 8000 scheme: HTTP periodSeconds: 60 failureThreshold: 100 - name: jetstream-http image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.3 imagePullPolicy: Always ports: - containerPort: 8000 # The size variable defines the number of worker nodes to be created. # It must be equal to the number of hosts + 1 (for the leader node). size: 5 workerTemplate: spec: nodeSelector: cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 2x4 containers: - name: worker args: - --server_port=38679 - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677 - --gcs_scratch_location=gs://cloud-pathways-staging/tmp image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest imagePullPolicy: Always ports: - containerPort: 38679 resources: limits: google.com/tpu: "4" --- apiVersion: v1 kind: Service metadata: name: jetstream-svc spec: selector: app: jetstream-pathways ports: - protocol: TCP name: jetstream-http port: 8000 targetPort: 8000
TPU_ACCELERATOR_TYPE: Jenis akselerator TPU. Misalnya,tpu-v6e-slice.TPU_TOPOLOGY: Topologi TPU. Misalnya,2x4.CHECKPOINT_PATH: Jalur Cloud Storage ke checkpoint.
- Terapkan YAML ini, server model akan memerlukan waktu beberapa saat untuk memulihkan
checkpoint. Untuk model 70B, hal ini mungkin memerlukan waktu sekitar 2 menit.
kubectl apply -f pathways-job.yaml
- Lihat log Kubernetes untuk melihat apakah server model JetStream sudah siap:
Anda akan melihat output yang mirip dengan berikut ini yang menunjukkan bahwa server model JetStream siap untuk menayangkan permintaan:kubectl logs -f jetstream-pathways-0 -c jax-tpu
2025-03-02 02:15:07,682 - JetstreamLogger - INFO - Initializing the driver with 1 prefill engines and 1 generate engines in interleaved mode 2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up prefill thread 0. 2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up transfer thread 0. 2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up generate thread 0. 2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up detokenize thread 0. 2025-03-02 02:15:07,685 - JetstreamLogger - INFO - Driver initialized. ... ... ... INFO: Started server process [7] INFO: Waiting for application startup. INFO: Application startup complete. INFO: Uvicorn running on http://0.0.0.0:9999 (Press CTRL+C to quit)
Menghubungkan ke server model
Anda dapat mengakses deployment JetStream Pathways melalui layanan ClusterIP GKE. Layanan ClusterIP hanya dapat dijangkau dari dalam cluster. Oleh karena itu, untuk mengakses layanan dari luar cluster, buat sesi penerusan port dengan menjalankan perintah berikut:
kubectl port-forward pod/${HEAD_POD} 8000:8000
Pastikan Anda dapat mengakses server HTTP JetStream dengan membuka terminal baru dan menjalankan perintah berikut:
curl --request POST \
--header "Content-type: application/json" \
-s \
localhost:8000/generate \
--data \
'{
"prompt": "What are the top 5 programming languages",
"max_tokens": 200
}'
Permintaan awal dapat memerlukan waktu beberapa detik untuk diselesaikan karena pemanasan model. Outputnya akan mirip dengan berikut ini:
{
"response": " used in software development?\nThe top 5 programming languages used in software development are:\n\n1. Java: Java is a popular programming language used for developing enterprise-level applications, Android apps, and web applications. Its platform independence and ability to run on any device that has a Java Virtual Machine (JVM) installed make it a favorite among developers.\n2. Python: Python is a versatile language that is widely used in software development, data analysis, artificial intelligence, and machine learning. Its simplicity, readability, and ease of use make it a popular choice among developers.\n3. JavaScript: JavaScript is a widely used programming language for web development, allowing developers to create interactive client-side functionality for web applications. It is also used for server-side programming, desktop and mobile application development, and game development.\n4. C++: C++ is a high-performance programming language used for developing operating systems, games, and other high-performance applications."
}
Langkah berikutnya
- Workload batch dengan Pathways
- Mode interaktif Pathways
- Memindahkan workload JAX ke Pathways
- Pelatihan yang tangguh dengan Pathways
- Memecahkan masalah Pathways