Inferensi multihost adalah metode menjalankan inferensi model yang mendistribusikan model di beberapa host akselerator. Hal ini memungkinkan inferensi model besar yang tidak sesuai dengan satu host. Jalur dapat di-deploy untuk kasus penggunaan inferensi multihost batch dan real-time.
Sebelum memulai
Pastikan Anda memiliki:
- Membuat cluster GKE yang menggunakan chip Trillium (v6e-16).
- Alat Kubernetes yang terinstal
- Mengaktifkan TPU API
- Mengaktifkan Google Kubernetes Engine API
Menjalankan Inferensi batch menggunakan JetStream
JetStream adalah mesin yang dioptimalkan untuk throughput dan memori untuk inferensi model bahasa besar (LLM) di perangkat XLA, terutama Tensor Processing Unit (TPU) yang ditulis dalam JAX.
Anda dapat menggunakan image Docker JetStream yang telah dibuat sebelumnya untuk menjalankan beban kerja inferensi batch,
seperti yang ditunjukkan dalam YAML berikut. Container ini dibangun dari
project JetStream OSS.
Untuk mengetahui informasi selengkapnya tentang tanda MaxText-JetStream, lihat
Tanda server MaxText JetStream.
Contoh berikut menggunakan chip Trillium (v6e-16) untuk memuat checkpoint int8 Llama3.1-405b dan melakukan inferensi di atasnya. Contoh ini mengasumsikan bahwa Anda sudah
memiliki cluster GKE dengan setidaknya satu kumpulan node v6e-16 di dalamnya.
Mulai server model dan Pathways
- Dapatkan kredensial ke cluster dan tambahkan ke konteks kubectl lokal Anda.
gcloud container clusters get-credentials $CLUSTER \ --zone=$ZONE \ --project=$PROJECT \ && kubectl config set-context --current --namespace=default
- Men-deploy LeaderWorkerSet (LWS) API.
VERSION=v0.4.0 kubectl apply --server-side -f "https://github.com/kubernetes-sigs/lws/releases/download/${VERSION}/manifests.yaml"
- Salin dan tempel YAML berikut ke dalam file bernama
pathways-job.yaml: YAML ini telah dioptimalkan untuk bentuk irisanv6e-16. Untuk mengetahui informasi selengkapnya tentang cara mengonversi checkpoint Meta menjadi checkpoint yang kompatibel dengan JAX, ikuti panduan pembuatan checkpoint di Membuat checkpoint inferensi. Sebagai contoh, petunjuk untuk Llama3.1-405B disediakan di sini Konversi titik pemeriksaan untuk Llama3.1-405B. Ganti kode berikut:apiVersion: leaderworkerset.x-k8s.io/v1 kind: LeaderWorkerSet metadata: name: jetstream-pathways annotations: leaderworkerset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool spec: replicas: 1 leaderWorkerTemplate: leaderTemplate: metadata: labels: app: jetstream-pathways spec: nodeSelector: cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 4x4 tolerations: - key: "google.com/tpu" operator: "Exists" effect: "NoSchedule" containers: - name: pathways-proxy image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest args: - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677 - --server_port=38681 - --gcs_scratch_location=gs://cloud-pathways-staging/tmp imagePullPolicy: Always ports: - containerPort: 38681 - name: pathways-rm env: - name: HOST_ADDRESS value: "$(LWS_LEADER_ADDRESS)" - name: TPU_SKIP_MDS_QUERY value: "true" image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest args: - --server_port=38677 - --gcs_scratch_location=gs://cloud-pathways-staging/tmp - --node_type=resource_manager - --instance_count=1 - --instance_type=tpuv6e:TPU_TOPOLOGY # Example: 4x4 - --temporary_flags_for_debugging=temporary_flag_for_debugging_worker_expected_tpu_chip_config=megachip_tccontrol imagePullPolicy: Always ports: - containerPort: 38677 - name: jax-tpu image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-pathways:v0.2.0 # Optimized settings used to serve Llama3.1-405b. args: - MaxText/configs/v5e/inference/llama3_405b_v5e-64.yml - model_name=llama3.1-405b - load_parameters_path=GCS_CHECKPOINT_PATH - max_prefill_predict_length=1024 - max_target_length=2048 - async_checkpointing=false - steps=1 - ici_fsdp_parallelism=1 - ici_autoregressive_parallelism=2 - ici_tensor_parallelism=8 - scan_layers=false - weight_dtype=bfloat16 - per_device_batch_size=10 - enable_single_controller=true - quantization=int8 - quantize_kvcache=true - checkpoint_is_quantized=true - enable_model_warmup=true imagePullPolicy: Always ports: - containerPort: 9000 startupProbe: httpGet: path: /healthcheck port: 8000 scheme: HTTP periodSeconds: 1 initialDelaySeconds: 900 failureThreshold: 10000 livenessProbe: httpGet: path: /healthcheck port: 8000 scheme: HTTP periodSeconds: 60 failureThreshold: 10 readinessProbe: httpGet: path: /healthcheck port: 8000 scheme: HTTP periodSeconds: 60 failureThreshold: 10 - name: jetstream-http image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.3 imagePullPolicy: Always ports: - containerPort: 8000 # The size variable defines the number of worker nodes to be created. # It must be equal to the number of hosts + 1 (for the leader node). size: 5 workerTemplate: spec: nodeSelector: cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 4x4 tolerations: - key: "google.com/tpu" operator: "Exists" effect: "NoSchedule" containers: - name: worker args: - --server_port=38679 - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677 - --gcs_scratch_location=gs://cloud-pathways-staging/tmp image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest imagePullPolicy: Always ports: - containerPort: 38679 resources: limits: google.com/tpu: "4" --- apiVersion: v1 kind: Service metadata: name: jetstream-svc spec: selector: app: jetstream-pathways ports: - protocol: TCP name: jetstream-http port: 8000 targetPort: 8000
TPU_ACCELERATOR_TYPE: Jenis akselerator TPU. Contoh,tpu-v6e-slice.TPU_TOPOLOGY: Topologi TPU. Contoh,2x4.GCS_CHECKPOINT_PATH: Jalur GCS ke titik pemeriksaan.
- Lihat log Kubernetes untuk mengetahui apakah server model JetStream sudah siap:
Workload diberi nama `jetstream-pathways` di YAML sebelumnya, dan `0`
adalah node head.
Outputnya mirip dengan berikut yang menunjukkan bahwa server model JetStream siap melayani permintaan:kubectl logs -f jetstream-pathways-0 -c jax-tpu
2025-03-02 02:15:07,682 - JetstreamLogger - INFO - Initializing the driver with 1 prefill engines and 1 generate engines in interleaved mode 2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up prefill thread 0. 2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up transfer thread 0. 2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up generate thread 0. 2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up detokenize thread 0. 2025-03-02 02:15:07,685 - JetstreamLogger - INFO - Driver initialized. ... ... ... INFO: Started server process [7] INFO: Waiting for application startup. INFO: Application startup complete. INFO: Uvicorn running on http://0.0.0.0:9999 (Press CTRL+C to quit)
Hubungkan ke server model
Anda dapat mengakses deployment JetStream Pathways menggunakan layanan ClusterIP GKE. Layanan ClusterIP hanya dapat dijangkau dari dalam cluster. Oleh karena itu, untuk mengakses layanan dari luar cluster, Anda harus membuat sesi penerusan port terlebih dahulu dengan menjalankan perintah berikut:
kubectl port-forward pod/${HEAD_POD} 8000:8000
Pastikan Anda dapat mengakses server HTTP JetStream dengan membuka terminal baru dan menjalankan perintah berikut:
curl --request POST \
--header "Content-type: application/json" \
-s \
localhost:8000/generate \
--data \
'{
"prompt": "What are the top 5 programming languages",
"max_tokens": 200
}'
Permintaan awal dapat memerlukan waktu beberapa detik untuk diselesaikan karena pemanasan model. Outputnya akan mirip dengan berikut ini:
{
"response": " for web development?\nThe top 5 programming languages for web development are:\n1. **JavaScript**: JavaScript is the most popular language for web development, used by over 90% of websites for client-side scripting. It's also popular for server-side programming with technologies like Node.js.\n2. **HTML/CSS**: HTML (Hypertext Markup Language) and CSS (Cascading Style Sheets) are not programming languages, but are essential for building websites. HTML is used for structuring content, while CSS is used for styling and layout.\n3. **Python**: Python is a popular language for web development, especially with frameworks like Django and Flask. It's known for its simplicity, flexibility, and large community of developers.\n4. **Java**: Java is a popular language for building enterprise-level web applications, especially with frameworks like Spring and Hibernate. It's known for its platform independence, strong security features, and large community of developers.\n5. **PHP**: PHP is a mature language for web"
}
Inferensi yang diuraikan
Penayangan yang tidak digabungkan adalah teknik untuk menjalankan model bahasa besar (LLM) yang memisahkan tahap pengisian awal dan dekode ke dalam proses yang berbeda, yang berpotensi berjalan di mesin yang berbeda. Hal ini memungkinkan pemanfaatan resource yang lebih baik dan dapat meningkatkan performa dan efisiensi, terutama untuk model besar.
- Prefill: tahap ini memproses perintah input dan menghasilkan representasi perantara (seperti cache key-value). Proses ini sering kali membutuhkan banyak komputasi.
- Dekode: tahap ini menghasilkan token output, satu per satu, menggunakan representasi pra-pengisian. Biasanya terikat dengan bandwidth memori.
Dengan memisahkan tahap ini, penayangan yang diuraikan memungkinkan pra-pengisian dan decoding berjalan secara paralel, sehingga meningkatkan throughput dan latensi.
Untuk mengaktifkan penayangan yang dipisah-pisahkan, ubah YAML berikut untuk menggunakan dua slice v6e-8: satu untuk pengisian otomatis dan yang lainnya untuk pembuatan. Sebelum melanjutkan, pastikan cluster GKE Anda memiliki setidaknya dua node pool yang dikonfigurasi dengan topologi v6e-8 ini.
Untuk performa yang optimal, flag XLA tertentu telah dikonfigurasi.
Buat checkpoint llama2-70b dengan mengikuti proses yang sama seperti pembuatan checkpoint llama3.1-405b, yang dijelaskan di bagian sebelumnya.
- Untuk meluncurkan server JetStream dalam mode yang tidak digabungkan menggunakan Pathways, salin
dan tempel YAML berikut ke dalam file bernama
pathways-job.yaml: Ganti kode berikut:apiVersion: leaderworkerset.x-k8s.io/v1 kind: LeaderWorkerSet metadata: name: jetstream-pathways annotations: leaderworkerset.sigs.k8s.io/subgroup-exclusive-topology: cloud.google.com/gke-nodepool spec: replicas: 1 leaderWorkerTemplate: subGroupPolicy: subGroupSize: 2 leaderTemplate: metadata: labels: app: jetstream-pathways spec: nodeSelector: cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 2x4 tolerations: - key: "google.com/tpu" operator: "Exists" effect: "NoSchedule" containers: - name: pathways-proxy image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest args: - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677 - --server_port=38681 - --gcs_scratch_location=gs://cloud-pathways-staging/tmp imagePullPolicy: Always ports: - containerPort: 38681 - name: pathways-rm env: - name: HOST_ADDRESS value: "$(LWS_LEADER_ADDRESS)" - name: TPU_SKIP_MDS_QUERY value: "true" image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest args: - --server_port=38677 - --gcs_scratch_location=gs://cloud-pathways-staging/tmp - --node_type=resource_manager - --instance_count=2 - --instance_type=tpuv6e:TPU_TOPOLOGY # Example: 2x4 imagePullPolicy: Always ports: - containerPort: 38677 - name: jax-tpu image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-pathways:v0.2.0 # Optimized settings used to serve Llama2-70b. args: - MaxText/configs/base.yml - tokenizer_path=assets/tokenizer.llama2 - load_parameters_path=GCS_CHECKPOINT_PATH - max_prefill_predict_length=1024 - max_target_length=2048 - model_name=llama2-70b - ici_fsdp_parallelism=1 - ici_autoregressive_parallelism=1 - ici_tensor_parallelism=-1 - scan_layers=false - weight_dtype=bfloat16 - per_device_batch_size=1 - checkpoint_is_quantized=true - quantization=int8 - quantize_kvcache=true - compute_axis_order=0,2,1,3 - ar_cache_axis_order=0,2,1,3 - stack_prefill_result_cache=True # Specify disaggregated mode to run Jetstream - inference_server=ExperimentalMaxtextDisaggregatedServer_8 - inference_benchmark_test=True - enable_model_warmup=True env: - name: LOG_LEVEL value: "INFO" imagePullPolicy: Always securityContext: capabilities: add: ["SYS_PTRACE", "NET_ADMIN", "SYS_TIME"] ports: - containerPort: 9000 startupProbe: httpGet: path: /healthcheck port: 8000 scheme: HTTP periodSeconds: 1 initialDelaySeconds: 240 failureThreshold: 10000 livenessProbe: httpGet: path: /healthcheck port: 8000 scheme: HTTP periodSeconds: 60 failureThreshold: 100 readinessProbe: httpGet: path: /healthcheck port: 8000 scheme: HTTP periodSeconds: 60 failureThreshold: 100 - name: jetstream-http image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.3 imagePullPolicy: Always ports: - containerPort: 8000 # The size variable defines the number of worker nodes to be created. # It must be equal to the number of hosts + 1 (for the leader node). size: 5 workerTemplate: spec: nodeSelector: cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 2x4 containers: - name: worker args: - --server_port=38679 - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677 - --gcs_scratch_location=gs://cloud-pathways-staging/tmp image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest imagePullPolicy: Always ports: - containerPort: 38679 resources: limits: google.com/tpu: "4" --- apiVersion: v1 kind: Service metadata: name: jetstream-svc spec: selector: app: jetstream-pathways ports: - protocol: TCP name: jetstream-http port: 8000 targetPort: 8000
TPU_ACCELERATOR_TYPE: Jenis akselerator TPU. Contoh,tpu-v6e-slice.TPU_TOPOLOGY: Topologi TPU. Contoh,2x4.GCS_CHECKPOINT_PATH: Jalur GCS ke titik pemeriksaan.
- Terapkan YAML ini, server model akan memerlukan waktu beberapa saat untuk memulihkan
checkpoint. Untuk model 70B, proses ini dapat memakan waktu ~2 menit.
kubectl apply -f pathways-job.yaml
- Lihat log Kubernetes untuk melihat apakah server model JetStream sudah siap:
Anda akan melihat output yang mirip dengan berikut yang menunjukkan bahwa server model JetStream siap melayani permintaan:kubectl logs -f jetstream-pathways-0 -c jax-tpu
2025-03-02 02:15:07,682 - JetstreamLogger - INFO - Initializing the driver with 1 prefill engines and 1 generate engines in interleaved mode 2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up prefill thread 0. 2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up transfer thread 0. 2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up generate thread 0. 2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up detokenize thread 0. 2025-03-02 02:15:07,685 - JetstreamLogger - INFO - Driver initialized. ... ... ... INFO: Started server process [7] INFO: Waiting for application startup. INFO: Application startup complete. INFO: Uvicorn running on http://0.0.0.0:9999 (Press CTRL+C to quit)
Hubungkan ke server model
Anda dapat mengakses deployment JetStream Pathways melalui layanan ClusterIP GKE. Layanan ClusterIP hanya dapat dijangkau dari dalam cluster. Oleh karena itu, untuk mengakses layanan dari luar cluster, buat sesi penerusan port dengan menjalankan perintah berikut:
kubectl port-forward pod/${HEAD_POD} 8000:8000
Pastikan Anda dapat mengakses server HTTP JetStream dengan membuka terminal baru dan menjalankan perintah berikut:
curl --request POST \
--header "Content-type: application/json" \
-s \
localhost:8000/generate \
--data \
'{
"prompt": "What are the top 5 programming languages",
"max_tokens": 200
}'
Permintaan awal dapat memerlukan waktu beberapa detik untuk diselesaikan karena pemanasan model. Outputnya akan mirip dengan berikut ini:
{
"response": " used in software development?\nThe top 5 programming languages used in software development are:\n\n1. Java: Java is a popular programming language used for developing enterprise-level applications, Android apps, and web applications. Its platform independence and ability to run on any device that has a Java Virtual Machine (JVM) installed make it a favorite among developers.\n2. Python: Python is a versatile language that is widely used in software development, data analysis, artificial intelligence, and machine learning. Its simplicity, readability, and ease of use make it a popular choice among developers.\n3. JavaScript: JavaScript is a widely used programming language for web development, allowing developers to create interactive client-side functionality for web applications. It is also used for server-side programming, desktop and mobile application development, and game development.\n4. C++: C++ is a high-performance programming language used for developing operating systems, games, and other high-performance applications."
}
Langkah berikutnya
- Batch workload dengan Pathways
- Mode interaktif jalur
- Memindahkan beban kerja JAX ke Pathways
- Pelatihan yang tangguh dengan Pathways
- Pemecahan Masalah Jalur