Multi-host inference is a method of running model inference with the model distributed across multiple accelerator hosts. This makes it possible to serve large models that cannot run on a single host. Pathways can be used for both batch and live multi-host inference use cases.
Before you begin
Make sure you have:
- Created a GKE cluster that uses Trillium chips (v6e-16).
- Installed the Kubernetes tooling.
- Enabled the TPU API.
- Enabled the Google Kubernetes Engine API.
Run batch inference with JetStream
JetStream is a throughput- and memory-optimized engine for large language model (LLM) inference on XLA devices (primarily Tensor Processing Units, or TPUs), written in JAX.
You can run batch inference workloads with the prebuilt JetStream Docker image shown in the following YAML. The container is built from the open-source JetStream project.
For more information about MaxText-JetStream flags, see JetStream MaxText server flags. The following example loads a Llama3.1-405b int8 checkpoint and runs inference against it on Trillium chips (v6e-16). It assumes you already have a GKE cluster with at least one v6e-16 node pool.
Start the model server and Pathways
- Get credentials for your cluster and add them to your local kubectl context:
gcloud container clusters get-credentials $CLUSTER \
    --zone=$ZONE \
    --project=$PROJECT \
  && kubectl config set-context --current --namespace=default
- Deploy the LeaderWorkerSet (LWS) API:
VERSION=v0.4.0
kubectl apply --server-side -f "https://github.com/kubernetes-sigs/lws/releases/download/${VERSION}/manifests.yaml"
- Copy and paste the following YAML into a file named pathways-job.yaml. This YAML is tuned for the v6e-16 slice shape. For details on converting a Meta checkpoint into a JAX-compatible checkpoint, follow the checkpoint-creation guidance in Create an inference checkpoint. For example, instructions for Llama3.1-405B are available at Checkpoint conversion for Llama3.1-405B.
apiVersion: leaderworkerset.x-k8s.io/v1
kind: LeaderWorkerSet
metadata:
  name: jetstream-pathways
  annotations:
    leaderworkerset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
spec:
  replicas: 1
  leaderWorkerTemplate:
    leaderTemplate:
      metadata:
        labels:
          app: jetstream-pathways
      spec:
        nodeSelector:
          cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice
          cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 4x4
        tolerations:
        - key: "google.com/tpu"
          operator: "Exists"
          effect: "NoSchedule"
        containers:
        - name: pathways-proxy
          image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest
          args:
          - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677
          - --server_port=38681
          - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
          imagePullPolicy: Always
          ports:
          - containerPort: 38681
        - name: pathways-rm
          env:
          - name: HOST_ADDRESS
            value: "$(LWS_LEADER_ADDRESS)"
          - name: TPU_SKIP_MDS_QUERY
            value: "true"
          image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest
          args:
          - --server_port=38677
          - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
          - --node_type=resource_manager
          - --instance_count=1
          - --instance_type=tpuv6e:TPU_TOPOLOGY # Example: 4x4
          - --temporary_flags_for_debugging=temporary_flag_for_debugging_worker_expected_tpu_chip_config=megachip_tccontrol
          imagePullPolicy: Always
          ports:
          - containerPort: 38677
        - name: jax-tpu
          image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-pathways:v0.2.0
          # Optimized settings used to serve Llama3.1-405b.
          args:
          - MaxText/configs/v5e/inference/llama3_405b_v5e-64.yml
          - model_name=llama3.1-405b
          - load_parameters_path=GCS_CHECKPOINT_PATH
          - max_prefill_predict_length=1024
          - max_target_length=2048
          - async_checkpointing=false
          - steps=1
          - ici_fsdp_parallelism=1
          - ici_autoregressive_parallelism=2
          - ici_tensor_parallelism=8
          - scan_layers=false
          - weight_dtype=bfloat16
          - per_device_batch_size=10
          - enable_single_controller=true
          - quantization=int8
          - quantize_kvcache=true
          - checkpoint_is_quantized=true
          - enable_model_warmup=true
          imagePullPolicy: Always
          ports:
          - containerPort: 9000
          startupProbe:
            httpGet:
              path: /healthcheck
              port: 8000
              scheme: HTTP
            periodSeconds: 1
            initialDelaySeconds: 900
            failureThreshold: 10000
          livenessProbe:
            httpGet:
              path: /healthcheck
              port: 8000
              scheme: HTTP
            periodSeconds: 60
            failureThreshold: 10
          readinessProbe:
            httpGet:
              path: /healthcheck
              port: 8000
              scheme: HTTP
            periodSeconds: 60
            failureThreshold: 10
        - name: jetstream-http
          image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.3
          imagePullPolicy: Always
          ports:
          - containerPort: 8000
    # The size variable defines the number of worker nodes to be created.
    # It must be equal to the number of hosts + 1 (for the leader node).
    size: 5
    workerTemplate:
      spec:
        nodeSelector:
          cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice
          cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 4x4
        tolerations:
        - key: "google.com/tpu"
          operator: "Exists"
          effect: "NoSchedule"
        containers:
        - name: worker
          args:
          - --server_port=38679
          - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677
          - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
          image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest
          imagePullPolicy: Always
          ports:
          - containerPort: 38679
          resources:
            limits:
              google.com/tpu: "4"
---
apiVersion: v1
kind: Service
metadata:
  name: jetstream-svc
spec:
  selector:
    app: jetstream-pathways
  ports:
  - protocol: TCP
    name: jetstream-http
    port: 8000
    targetPort: 8000
Replace the following:
TPU_ACCELERATOR_TYPE: the TPU accelerator type. For example, tpu-v6e-slice.
TPU_TOPOLOGY: the TPU topology. For example, 4x4.
GCS_CHECKPOINT_PATH: the GCS path of your checkpoint.
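Substituting the placeholders can be scripted. A minimal sketch, assuming the manifest was saved as pathways-job.yaml; the accelerator type, topology, and checkpoint path below are hypothetical examples, so adjust them to your cluster and bucket:

```python
# Fill in the pathways-job.yaml placeholders in one pass.
# The values below are examples only -- use your own accelerator type,
# topology, and checkpoint path.
from pathlib import Path

REPLACEMENTS = {
    "TPU_ACCELERATOR_TYPE": "tpu-v6e-slice",
    "TPU_TOPOLOGY": "4x4",
    "GCS_CHECKPOINT_PATH": "gs://my-bucket/llama3.1-405b-int8/checkpoint",
}

def fill_placeholders(text: str, replacements: dict) -> str:
    """Return `text` with every placeholder replaced by its value."""
    for placeholder, value in replacements.items():
        text = text.replace(placeholder, value)
    return text

if __name__ == "__main__":
    manifest = Path("pathways-job.yaml")
    if manifest.exists():  # guard so the sketch is safe to import elsewhere
        manifest.write_text(fill_placeholders(manifest.read_text(), REPLACEMENTS))
```

A plain `str.replace` is enough here because the placeholders are unique, uppercase tokens that cannot collide with real YAML keys in the manifest.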
- Apply the manifest:
kubectl apply -f pathways-job.yaml
- Check the Kubernetes logs to see whether the JetStream model server is ready:
kubectl logs -f jetstream-pathways-0 -c jax-tpu
In the preceding YAML, the workload is named jetstream-pathways, and 0 refers to the leader (head) pod.
Output similar to the following indicates that the JetStream model server is ready to serve requests:
2025-03-02 02:15:07,682 - JetstreamLogger - INFO - Initializing the driver with 1 prefill engines and 1 generate engines in interleaved mode
2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up prefill thread 0.
2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up transfer thread 0.
2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up generate thread 0.
2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up detokenize thread 0.
2025-03-02 02:15:07,685 - JetstreamLogger - INFO - Driver initialized.
...
...
...
INFO: Started server process [7]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://0.0.0.0:9999 (Press CTRL+C to quit)
Connect to the model server
You can access the JetStream Pathways deployment through GKE's ClusterIP Service. A ClusterIP Service is reachable only from within the cluster, so to reach the Service from outside the cluster you must first establish a port-forwarding session by running the following command:
kubectl port-forward pod/${HEAD_POD} 8000:8000
Verify that you can access the JetStream HTTP server by opening a new terminal and running the following command:
curl --request POST \
--header "Content-type: application/json" \
-s \
localhost:8000/generate \
--data \
'{
"prompt": "What are the top 5 programming languages",
"max_tokens": 200
}'
The initial request can take several seconds to complete because of model warmup. The output should look similar to the following:
{
"response": " for web development?\nThe top 5 programming languages for web development are:\n1. **JavaScript**: JavaScript is the most popular language for web development, used by over 90% of websites for client-side scripting. It's also popular for server-side programming with technologies like Node.js.\n2. **HTML/CSS**: HTML (Hypertext Markup Language) and CSS (Cascading Style Sheets) are not programming languages, but are essential for building websites. HTML is used for structuring content, while CSS is used for styling and layout.\n3. **Python**: Python is a popular language for web development, especially with frameworks like Django and Flask. It's known for its simplicity, flexibility, and large community of developers.\n4. **Java**: Java is a popular language for building enterprise-level web applications, especially with frameworks like Spring and Hibernate. It's known for its platform independence, strong security features, and large community of developers.\n5. **PHP**: PHP is a mature language for web"
}
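The same request can be issued from Python instead of curl. A minimal sketch using only the standard library; it assumes the port-forwarding session from the previous step is active on localhost:8000:

```python
# Minimal Python client for the JetStream HTTP /generate endpoint,
# equivalent to the curl command above. Standard library only.
import json
import urllib.request

def build_generate_request(prompt: str, max_tokens: int,
                           host: str = "localhost", port: int = 8000):
    """Build a POST request carrying the JSON payload the server expects."""
    payload = json.dumps({"prompt": prompt, "max_tokens": max_tokens}).encode()
    return urllib.request.Request(
        f"http://{host}:{port}/generate",
        data=payload,
        headers={"Content-type": "application/json"},
        method="POST",
    )

def generate(prompt: str, max_tokens: int = 200) -> str:
    """Send the request and return the model's "response" field."""
    req = build_generate_request(prompt, max_tokens)
    with urllib.request.urlopen(req) as resp:
        return json.loads(resp.read())["response"]
```

Calling `generate("What are the top 5 programming languages")` returns just the generated text, already extracted from the JSON envelope shown above.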
Disaggregated inference
Disaggregated serving is a technique for running large language models (LLMs) that separates the prefill and decode phases into different processes, potentially on different machines. This enables better resource utilization and can improve performance and efficiency, especially for large models.
- Prefill: this phase processes the input prompt and produces an intermediate representation (such as the key-value cache). It is typically compute-intensive.
- Decode: this phase uses the prefill representation to generate output tokens one at a time. It is typically bound by memory bandwidth.
By separating these phases, disaggregated serving can run prefill and decode in parallel, improving throughput and reducing latency.
To enable disaggregated serving, modify the following YAML to use two v6e-8 slices: one for prefill and one for generation. Before proceeding, make sure your GKE cluster has at least two node pools configured with the v6e-8 topology. Specific XLA flags are configured for optimal performance.
Create a llama2-70b checkpoint by following the same process used for the llama3.1-405b checkpoint, as described in the previous section.
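The division of labor can be sketched as a toy pipeline. This is illustrative only, not the JetStream implementation: a prefill worker turns prompts into a stand-in "KV cache", a decode worker consumes those caches to emit placeholder tokens, and a queue between them lets the two stages overlap:

```python
# Toy model of disaggregated serving: prefill and decode run as separate
# workers connected by a queue, so prefill of request N+1 can overlap with
# decode of request N. Real KV caches are tensors; here they are token lists.
import queue
import threading

def prefill(prompt: str) -> dict:
    """Stand-in for the compute-heavy prefill phase."""
    return {"prompt": prompt, "kv_cache": prompt.split()}

def decode(state: dict, max_tokens: int = 3) -> list:
    """Stand-in for the memory-bandwidth-bound decode phase."""
    return [f"tok{i}" for i in range(max_tokens)]

def serve(prompts: list) -> list:
    handoff = queue.Queue()   # the prefill -> decode handoff channel
    results = []

    def prefill_worker():
        for p in prompts:
            handoff.put(prefill(p))  # runs ahead while decode drains
        handoff.put(None)            # sentinel: no more requests

    def decode_worker():
        while (state := handoff.get()) is not None:
            results.append(decode(state))

    threads = [threading.Thread(target=prefill_worker),
               threading.Thread(target=decode_worker)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results
```

In the real deployment below, the two "workers" are separate v6e-8 slices and the handoff is a transfer of the KV cache between hosts, but the overlap idea is the same.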
- To launch the JetStream server in disaggregated mode with Pathways, copy and paste the following YAML into a file named pathways-job.yaml:
apiVersion: leaderworkerset.x-k8s.io/v1
kind: LeaderWorkerSet
metadata:
  name: jetstream-pathways
  annotations:
    leaderworkerset.sigs.k8s.io/subgroup-exclusive-topology: cloud.google.com/gke-nodepool
spec:
  replicas: 1
  leaderWorkerTemplate:
    subGroupPolicy:
      subGroupSize: 2
    leaderTemplate:
      metadata:
        labels:
          app: jetstream-pathways
      spec:
        nodeSelector:
          cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice
          cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 2x4
        tolerations:
        - key: "google.com/tpu"
          operator: "Exists"
          effect: "NoSchedule"
        containers:
        - name: pathways-proxy
          image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest
          args:
          - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677
          - --server_port=38681
          - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
          imagePullPolicy: Always
          ports:
          - containerPort: 38681
        - name: pathways-rm
          env:
          - name: HOST_ADDRESS
            value: "$(LWS_LEADER_ADDRESS)"
          - name: TPU_SKIP_MDS_QUERY
            value: "true"
          image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest
          args:
          - --server_port=38677
          - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
          - --node_type=resource_manager
          - --instance_count=2
          - --instance_type=tpuv6e:TPU_TOPOLOGY # Example: 2x4
          imagePullPolicy: Always
          ports:
          - containerPort: 38677
        - name: jax-tpu
          image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-pathways:v0.2.0
          # Optimized settings used to serve Llama2-70b.
          args:
          - MaxText/configs/base.yml
          - tokenizer_path=assets/tokenizer.llama2
          - load_parameters_path=GCS_CHECKPOINT_PATH
          - max_prefill_predict_length=1024
          - max_target_length=2048
          - model_name=llama2-70b
          - ici_fsdp_parallelism=1
          - ici_autoregressive_parallelism=1
          - ici_tensor_parallelism=-1
          - scan_layers=false
          - weight_dtype=bfloat16
          - per_device_batch_size=1
          - checkpoint_is_quantized=true
          - quantization=int8
          - quantize_kvcache=true
          - compute_axis_order=0,2,1,3
          - ar_cache_axis_order=0,2,1,3
          - stack_prefill_result_cache=True
          # Specify disaggregated mode to run Jetstream
          - inference_server=ExperimentalMaxtextDisaggregatedServer_8
          - inference_benchmark_test=True
          - enable_model_warmup=True
          env:
          - name: LOG_LEVEL
            value: "INFO"
          imagePullPolicy: Always
          securityContext:
            capabilities:
              add: ["SYS_PTRACE", "NET_ADMIN", "SYS_TIME"]
          ports:
          - containerPort: 9000
          startupProbe:
            httpGet:
              path: /healthcheck
              port: 8000
              scheme: HTTP
            periodSeconds: 1
            initialDelaySeconds: 240
            failureThreshold: 10000
          livenessProbe:
            httpGet:
              path: /healthcheck
              port: 8000
              scheme: HTTP
            periodSeconds: 60
            failureThreshold: 100
          readinessProbe:
            httpGet:
              path: /healthcheck
              port: 8000
              scheme: HTTP
            periodSeconds: 60
            failureThreshold: 100
        - name: jetstream-http
          image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.3
          imagePullPolicy: Always
          ports:
          - containerPort: 8000
    # The size variable defines the number of worker nodes to be created.
    # It must be equal to the number of hosts + 1 (for the leader node).
    size: 5
    workerTemplate:
      spec:
        nodeSelector:
          cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice
          cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 2x4
        containers:
        - name: worker
          args:
          - --server_port=38679
          - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677
          - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
          image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest
          imagePullPolicy: Always
          ports:
          - containerPort: 38679
          resources:
            limits:
              google.com/tpu: "4"
---
apiVersion: v1
kind: Service
metadata:
  name: jetstream-svc
spec:
  selector:
    app: jetstream-pathways
  ports:
  - protocol: TCP
    name: jetstream-http
    port: 8000
    targetPort: 8000
Replace the following:
TPU_ACCELERATOR_TYPE: the TPU accelerator type. For example, tpu-v6e-slice.
TPU_TOPOLOGY: the TPU topology. For example, 2x4.
GCS_CHECKPOINT_PATH: the GCS path of your checkpoint.
- Apply the YAML:
kubectl apply -f pathways-job.yaml
After you apply the YAML, the model server takes some time to restore the checkpoint; for a 70B model, this can take about 2 minutes.
- Check the Kubernetes logs to see whether the JetStream model server is ready:
kubectl logs -f jetstream-pathways-0 -c jax-tpu
Output similar to the following indicates that the JetStream model server is ready to serve requests:
2025-03-02 02:15:07,682 - JetstreamLogger - INFO - Initializing the driver with 1 prefill engines and 1 generate engines in interleaved mode
2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up prefill thread 0.
2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up transfer thread 0.
2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up generate thread 0.
2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up detokenize thread 0.
2025-03-02 02:15:07,685 - JetstreamLogger - INFO - Driver initialized.
...
...
...
INFO: Started server process [7]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://0.0.0.0:9999 (Press CTRL+C to quit)
Connect to the model server
You can access the JetStream Pathways deployment through GKE's ClusterIP Service. A ClusterIP Service is reachable only from within the cluster, so to reach the Service from outside the cluster, establish a port-forwarding session by running the following command:
kubectl port-forward pod/${HEAD_POD} 8000:8000
Verify that you can access the JetStream HTTP server by opening a new terminal and running the following command:
curl --request POST \
--header "Content-type: application/json" \
-s \
localhost:8000/generate \
--data \
'{
"prompt": "What are the top 5 programming languages",
"max_tokens": 200
}'
The initial request can take several seconds to complete because of model warmup. The output should look similar to the following:
{
"response": " used in software development?\nThe top 5 programming languages used in software development are:\n\n1. Java: Java is a popular programming language used for developing enterprise-level applications, Android apps, and web applications. Its platform independence and ability to run on any device that has a Java Virtual Machine (JVM) installed make it a favorite among developers.\n2. Python: Python is a versatile language that is widely used in software development, data analysis, artificial intelligence, and machine learning. Its simplicity, readability, and ease of use make it a popular choice among developers.\n3. JavaScript: JavaScript is a widely used programming language for web development, allowing developers to create interactive client-side functionality for web applications. It is also used for server-side programming, desktop and mobile application development, and game development.\n4. C++: C++ is a high-performance programming language used for developing operating systems, games, and other high-performance applications."
}
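Once the server is answering, client-side latency is worth sanity-checking, since lower latency is one of the claimed benefits of disaggregated serving. A minimal sketch: the request function is passed in as a parameter, so this harness works with any client (for example, a urllib-based one against localhost:8000); it is a client-side measurement aid, not part of JetStream itself:

```python
# Measure wall-clock latency of successive /generate calls. `request_fn`
# is any callable that sends one request and blocks until the response
# arrives -- plug in your own client.
import time

def measure_latencies(request_fn, prompts):
    """Return one wall-clock latency (seconds) per prompt."""
    latencies = []
    for prompt in prompts:
        start = time.perf_counter()
        request_fn(prompt)
        latencies.append(time.perf_counter() - start)
    return latencies

def summarize(latencies):
    """Basic summary statistics over a list of latencies."""
    ordered = sorted(latencies)
    return {
        "count": len(ordered),
        "mean_s": sum(ordered) / len(ordered),
        "p50_s": ordered[len(ordered) // 2],
        "max_s": ordered[-1],
    }
```

Remember to discard the first measurement if the server has not finished model warmup, since warmup inflates the initial request's latency (as noted above).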