Multi-host inference is a model-serving approach that distributes a model across multiple accelerator hosts. This makes it possible to run inference on large models that can't fit on a single host. You can deploy Pathways for both batch and real-time multi-host inference use cases.
Before you begin

Make sure you have the following:
Run batch inference with JetStream
JetStream is a throughput- and memory-optimized engine for large language model (LLM) inference on XLA devices (primarily Tensor Processing Units, or TPUs), written in JAX.
You can run a batch inference workload using the prebuilt JetStream Docker image, as shown in the following YAML. The container is built from the open source JetStream project. To learn more about MaxText-JetStream flags, see "JetStream MaxText server flags". The following example loads a Llama3.1-405b int8 checkpoint and runs inference against it on Trillium chips (v6e-16). It assumes you already have a GKE cluster with at least one v6e-16 node pool.
Start the model server and Pathways
- Get the cluster credentials and add them to your local kubectl environment.

```
gcloud container clusters get-credentials $CLUSTER \
    --zone=$ZONE \
    --project=$PROJECT \
  && kubectl config set-context --current --namespace=default
```
- Deploy the LeaderWorkerSet (LWS) API.

```
VERSION=v0.4.0
kubectl apply --server-side -f "https://github.com/kubernetes-sigs/lws/releases/download/${VERSION}/manifests.yaml"
```
- Copy the following YAML into a file named `pathways-job.yaml`:

This YAML is optimized for the `v6e-16` slice shape. To convert a Meta checkpoint into a JAX-compatible checkpoint, follow the checkpoint-creation guide in "Create an inference checkpoint". For example, for Llama3.1-405B instructions, see "Checkpoint conversion for Llama3.1-405B".

```yaml
apiVersion: leaderworkerset.x-k8s.io/v1
kind: LeaderWorkerSet
metadata:
  name: jetstream-pathways
  annotations:
    leaderworkerset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
spec:
  replicas: 1
  leaderWorkerTemplate:
    leaderTemplate:
      metadata:
        labels:
          app: jetstream-pathways
      spec:
        nodeSelector:
          cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice
          cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 4x4
        tolerations:
        - key: "google.com/tpu"
          operator: "Exists"
          effect: "NoSchedule"
        containers:
        - name: pathways-proxy
          image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest
          args:
          - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677
          - --server_port=38681
          - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
          imagePullPolicy: Always
          ports:
          - containerPort: 38681
        - name: pathways-rm
          env:
          - name: HOST_ADDRESS
            value: "$(LWS_LEADER_ADDRESS)"
          - name: TPU_SKIP_MDS_QUERY
            value: "true"
          image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest
          args:
          - --server_port=38677
          - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
          - --node_type=resource_manager
          - --instance_count=1
          - --instance_type=tpuv6e:TPU_TOPOLOGY # Example: 4x4
          - --temporary_flags_for_debugging=temporary_flag_for_debugging_worker_expected_tpu_chip_config=megachip_tccontrol
          imagePullPolicy: Always
          ports:
          - containerPort: 38677
        - name: jax-tpu
          image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-pathways:v0.2.0
          # Optimized settings used to serve Llama3.1-405b.
          args:
          - MaxText/configs/v5e/inference/llama3_405b_v5e-64.yml
          - model_name=llama3.1-405b
          - load_parameters_path=GCS_CHECKPOINT_PATH
          - max_prefill_predict_length=1024
          - max_target_length=2048
          - async_checkpointing=false
          - steps=1
          - ici_fsdp_parallelism=1
          - ici_autoregressive_parallelism=2
          - ici_tensor_parallelism=8
          - scan_layers=false
          - weight_dtype=bfloat16
          - per_device_batch_size=10
          - enable_single_controller=true
          - quantization=int8
          - quantize_kvcache=true
          - checkpoint_is_quantized=true
          - enable_model_warmup=true
          imagePullPolicy: Always
          ports:
          - containerPort: 9000
          startupProbe:
            httpGet:
              path: /healthcheck
              port: 8000
              scheme: HTTP
            periodSeconds: 1
            initialDelaySeconds: 900
            failureThreshold: 10000
          livenessProbe:
            httpGet:
              path: /healthcheck
              port: 8000
              scheme: HTTP
            periodSeconds: 60
            failureThreshold: 10
          readinessProbe:
            httpGet:
              path: /healthcheck
              port: 8000
              scheme: HTTP
            periodSeconds: 60
            failureThreshold: 10
        - name: jetstream-http
          image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.3
          imagePullPolicy: Always
          ports:
          - containerPort: 8000
    # The size variable defines the number of worker nodes to be created.
    # It must be equal to the number of hosts + 1 (for the leader node).
    size: 5
    workerTemplate:
      spec:
        nodeSelector:
          cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice
          cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 4x4
        tolerations:
        - key: "google.com/tpu"
          operator: "Exists"
          effect: "NoSchedule"
        containers:
        - name: worker
          args:
          - --server_port=38679
          - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677
          - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
          image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest
          imagePullPolicy: Always
          ports:
          - containerPort: 38679
          resources:
            limits:
              google.com/tpu: "4"
---
apiVersion: v1
kind: Service
metadata:
  name: jetstream-svc
spec:
  selector:
    app: jetstream-pathways
  ports:
  - protocol: TCP
    name: jetstream-http
    port: 8000
    targetPort: 8000
```

Replace the following:
- TPU_ACCELERATOR_TYPE: the TPU accelerator type. For example: tpu-v6e-slice.
- TPU_TOPOLOGY: the TPU topology. For example: 4x4.
- GCS_CHECKPOINT_PATH: the GCS path to the checkpoint.
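The YAML comment above states that `size` must equal the number of TPU hosts plus one for the leader. As an illustration only (not part of the deployment), the host count can be derived from the slice topology; `lws_size` is a hypothetical helper, and the assumption of 4 chips per host matches v6e (each worker requests `google.com/tpu: "4"`):

```python
def lws_size(topology: str, chips_per_host: int = 4) -> int:
    """Return the LeaderWorkerSet `size` (worker hosts + 1 leader) for a
    TPU slice topology string such as "4x4". Illustrative helper only."""
    chips = 1
    for dim in topology.split("x"):
        chips *= int(dim)          # total chips in the slice
    hosts = chips // chips_per_host
    return hosts + 1               # workers plus the leader node

# v6e-16 (4x4 topology) has 16 chips on 4 hosts, so size: 5
print(lws_size("4x4"))  # -> 5
```

For a different slice shape, recompute `size` the same way; for example, a 2x4 (v6e-8) slice has 2 hosts, giving a size of 3.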
- Verify that the JetStream model server is ready by checking the Kubernetes logs:

```
kubectl logs -f jetstream-pathways-0 -c jax-tpu
```

`jetstream-pathways` is the workload name from the preceding YAML, and `0` refers to the leader node.

Output similar to the following indicates that the JetStream model server is ready to serve requests:

```
2025-03-02 02:15:07,682 - JetstreamLogger - INFO - Initializing the driver with 1 prefill engines and 1 generate engines in interleaved mode
2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up prefill thread 0.
2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up transfer thread 0.
2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up generate thread 0.
2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up detokenize thread 0.
2025-03-02 02:15:07,685 - JetstreamLogger - INFO - Driver initialized.
...
...
...
INFO:     Started server process [7]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:9999 (Press CTRL+C to quit)
```
Connect to the model server
You can access the JetStream Pathways deployment through GKE's ClusterIP service. A ClusterIP service is reachable only from within the cluster, so to access the service from outside the cluster, first establish a port-forwarding session by running the following command:
```
kubectl port-forward pod/${HEAD_POD} 8000:8000
```
Open a new terminal and run the following command to verify that you can reach the JetStream HTTP server:
```
curl --request POST \
  --header "Content-type: application/json" \
  -s \
  localhost:8000/generate \
  --data \
  '{
    "prompt": "What are the top 5 programming languages",
    "max_tokens": 200
  }'
```
The initial request can take several seconds to complete because of model warmup. You should see output similar to the following:
```
{
    "response": " for web development?\nThe top 5 programming languages for web development are:\n1. **JavaScript**: JavaScript is the most popular language for web development, used by over 90% of websites for client-side scripting. It's also popular for server-side programming with technologies like Node.js.\n2. **HTML/CSS**: HTML (Hypertext Markup Language) and CSS (Cascading Style Sheets) are not programming languages, but are essential for building websites. HTML is used for structuring content, while CSS is used for styling and layout.\n3. **Python**: Python is a popular language for web development, especially with frameworks like Django and Flask. It's known for its simplicity, flexibility, and large community of developers.\n4. **Java**: Java is a popular language for building enterprise-level web applications, especially with frameworks like Spring and Hibernate. It's known for its platform independence, strong security features, and large community of developers.\n5. **PHP**: PHP is a mature language for web"
}
```
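The same request can be made from Python using only the standard library. This sketch mirrors the curl call above; `build_generate_request` is an illustrative helper name, and the payload fields (`prompt`, `max_tokens`) are taken directly from the curl example:

```python
import json
import urllib.request

def build_generate_request(prompt: str, max_tokens: int,
                           host: str = "localhost", port: int = 8000):
    """Build a POST request for the JetStream HTTP /generate endpoint,
    matching the curl example. Illustrative helper, not part of JetStream."""
    body = json.dumps({"prompt": prompt, "max_tokens": max_tokens}).encode()
    return urllib.request.Request(
        f"http://{host}:{port}/generate",
        data=body,
        headers={"Content-type": "application/json"},
        method="POST",
    )

req = build_generate_request("What are the top 5 programming languages", 200)
# Requires the port-forwarding session from the previous step:
# resp = json.load(urllib.request.urlopen(req))
```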
Disaggregated inference
Disaggregated serving is a technique for running large language models (LLMs) that splits the prefill and decode phases into separate processes, potentially on different machines. This allows more efficient resource use and can improve performance and efficiency, especially for large models.
- Prefill: this phase processes the input prompt and produces an intermediate representation (such as a key-value cache). It is typically compute-intensive.
- Decode: this phase uses the prefilled representation to generate output tokens one at a time. It is typically limited by memory bandwidth.
By separating these phases, disaggregated serving lets prefill and decode run in parallel, improving throughput and reducing latency.
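Conceptually, the two phases can be sketched with a toy model in Python. This is purely illustrative: the real JetStream/MaxText engines operate on TPU arrays and transfer the key-value cache between slices, not Python lists:

```python
def prefill(prompt_tokens):
    """Process the whole prompt in one batched pass and return a
    KV-cache-like state. Compute-bound in real systems."""
    return list(prompt_tokens)  # stand-in for the key/value cache

def decode(kv_cache, steps):
    """Generate tokens one at a time, re-reading the full cache at each
    step. Memory-bandwidth-bound in real systems."""
    out = []
    for _ in range(steps):
        nxt = sum(kv_cache) % 100  # toy "model": next token from cache
        out.append(nxt)
        kv_cache.append(nxt)       # the cache grows with each new token
    return out

# In disaggregated serving, prefill and decode run in separate processes
# (here, on separate slices), with the cache handed off after prefill.
cache = prefill([3, 1, 4, 1, 5])
print(decode(cache, 3))  # -> [14, 28, 56]
```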
To enable disaggregated serving, modify the following YAML to use two v6e-8 slices: one for prefill and one for generation (decode). Before proceeding, make sure your GKE cluster is set up with at least two node pools with the v6e-8 topology. Specific XLA flags have been set for optimal performance.
Create a llama2-70b checkpoint by following the same procedure used to create the llama3.1-405b checkpoint. For details, see the previous section.
- To launch the JetStream server in disaggregated mode with Pathways, copy the following YAML into a file named `pathways-job.yaml`:

```yaml
apiVersion: leaderworkerset.x-k8s.io/v1
kind: LeaderWorkerSet
metadata:
  name: jetstream-pathways
  annotations:
    leaderworkerset.sigs.k8s.io/subgroup-exclusive-topology: cloud.google.com/gke-nodepool
spec:
  replicas: 1
  leaderWorkerTemplate:
    subGroupPolicy:
      subGroupSize: 2
    leaderTemplate:
      metadata:
        labels:
          app: jetstream-pathways
      spec:
        nodeSelector:
          cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice
          cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 2x4
        tolerations:
        - key: "google.com/tpu"
          operator: "Exists"
          effect: "NoSchedule"
        containers:
        - name: pathways-proxy
          image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest
          args:
          - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677
          - --server_port=38681
          - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
          imagePullPolicy: Always
          ports:
          - containerPort: 38681
        - name: pathways-rm
          env:
          - name: HOST_ADDRESS
            value: "$(LWS_LEADER_ADDRESS)"
          - name: TPU_SKIP_MDS_QUERY
            value: "true"
          image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest
          args:
          - --server_port=38677
          - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
          - --node_type=resource_manager
          - --instance_count=2
          - --instance_type=tpuv6e:TPU_TOPOLOGY # Example: 2x4
          imagePullPolicy: Always
          ports:
          - containerPort: 38677
        - name: jax-tpu
          image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-pathways:v0.2.0
          # Optimized settings used to serve Llama2-70b.
          args:
          - MaxText/configs/base.yml
          - tokenizer_path=assets/tokenizer.llama2
          - load_parameters_path=GCS_CHECKPOINT_PATH
          - max_prefill_predict_length=1024
          - max_target_length=2048
          - model_name=llama2-70b
          - ici_fsdp_parallelism=1
          - ici_autoregressive_parallelism=1
          - ici_tensor_parallelism=-1
          - scan_layers=false
          - weight_dtype=bfloat16
          - per_device_batch_size=1
          - checkpoint_is_quantized=true
          - quantization=int8
          - quantize_kvcache=true
          - compute_axis_order=0,2,1,3
          - ar_cache_axis_order=0,2,1,3
          - stack_prefill_result_cache=True
          # Specify disaggregated mode to run JetStream
          - inference_server=ExperimentalMaxtextDisaggregatedServer_8
          - inference_benchmark_test=True
          - enable_model_warmup=True
          env:
          - name: LOG_LEVEL
            value: "INFO"
          imagePullPolicy: Always
          securityContext:
            capabilities:
              add: ["SYS_PTRACE", "NET_ADMIN", "SYS_TIME"]
          ports:
          - containerPort: 9000
          startupProbe:
            httpGet:
              path: /healthcheck
              port: 8000
              scheme: HTTP
            periodSeconds: 1
            initialDelaySeconds: 240
            failureThreshold: 10000
          livenessProbe:
            httpGet:
              path: /healthcheck
              port: 8000
              scheme: HTTP
            periodSeconds: 60
            failureThreshold: 100
          readinessProbe:
            httpGet:
              path: /healthcheck
              port: 8000
              scheme: HTTP
            periodSeconds: 60
            failureThreshold: 100
        - name: jetstream-http
          image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.3
          imagePullPolicy: Always
          ports:
          - containerPort: 8000
    # The size variable defines the number of worker nodes to be created.
    # It must be equal to the number of hosts + 1 (for the leader node).
    size: 5
    workerTemplate:
      spec:
        nodeSelector:
          cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice
          cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 2x4
        containers:
        - name: worker
          args:
          - --server_port=38679
          - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677
          - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
          image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest
          imagePullPolicy: Always
          ports:
          - containerPort: 38679
          resources:
            limits:
              google.com/tpu: "4"
---
apiVersion: v1
kind: Service
metadata:
  name: jetstream-svc
spec:
  selector:
    app: jetstream-pathways
  ports:
  - protocol: TCP
    name: jetstream-http
    port: 8000
    targetPort: 8000
```

Replace the following:
- TPU_ACCELERATOR_TYPE: the TPU accelerator type. For example: tpu-v6e-slice.
- TPU_TOPOLOGY: the TPU topology. For example: 2x4.
- GCS_CHECKPOINT_PATH: the GCS path to the checkpoint.
- Apply the YAML:

```
kubectl apply -f pathways-job.yaml
```

After you apply this YAML, the model server takes some time to restore the checkpoint. For a 70B model, this can take about 2 minutes.
- Verify that the JetStream model server is ready by checking the Kubernetes logs:

```
kubectl logs -f jetstream-pathways-0 -c jax-tpu
```

Output similar to the following indicates that the JetStream model server is ready to serve requests:

```
2025-03-02 02:15:07,682 - JetstreamLogger - INFO - Initializing the driver with 1 prefill engines and 1 generate engines in interleaved mode
2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up prefill thread 0.
2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up transfer thread 0.
2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up generate thread 0.
2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up detokenize thread 0.
2025-03-02 02:15:07,685 - JetstreamLogger - INFO - Driver initialized.
...
...
...
INFO:     Started server process [7]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:9999 (Press CTRL+C to quit)
```
Connect to the model server
You can access the JetStream Pathways deployment through GKE's ClusterIP service. A ClusterIP service is reachable only from within the cluster, so to access the service from outside the cluster, establish a port-forwarding session by running the following command:
```
kubectl port-forward pod/${HEAD_POD} 8000:8000
```
Open a new terminal and run the following command to verify that you can reach the JetStream HTTP server:
```
curl --request POST \
  --header "Content-type: application/json" \
  -s \
  localhost:8000/generate \
  --data \
  '{
    "prompt": "What are the top 5 programming languages",
    "max_tokens": 200
  }'
```
The initial request can take several seconds to complete because of model warmup. You should see output similar to the following:
```
{
    "response": " used in software development?\nThe top 5 programming languages used in software development are:\n\n1. Java: Java is a popular programming language used for developing enterprise-level applications, Android apps, and web applications. Its platform independence and ability to run on any device that has a Java Virtual Machine (JVM) installed make it a favorite among developers.\n2. Python: Python is a versatile language that is widely used in software development, data analysis, artificial intelligence, and machine learning. Its simplicity, readability, and ease of use make it a popular choice among developers.\n3. JavaScript: JavaScript is a widely used programming language for web development, allowing developers to create interactive client-side functionality for web applications. It is also used for server-side programming, desktop and mobile application development, and game development.\n4. C++: C++ is a high-performance programming language used for developing operating systems, games, and other high-performance applications."
}
```