멀티 호스트 추론은 여러 가속기 호스트에 모델을 분산하는 모델 추론을 실행하는 방법입니다. 이를 통해 단일 호스트에 맞지 않는 대규모 모델의 추론이 가능합니다. 경로는 일괄 및 실시간 멀티 호스트 추론 사용 사례에 모두 배포할 수 있습니다.
시작하기 전에
다음 사항이 필요합니다.
- Trillium 칩 (v6e-16)을 사용하는 GKE 클러스터를 만들었습니다.
- 설치된 Kubernetes 도구
- TPU API 사용 설정
- Google Kubernetes Engine API 사용 설정
JetStream을 사용하여 일괄 추론 실행
JetStream은 JAX로 작성된 XLA 기기(주로 텐서 처리 장치(TPU))에서 대규모 언어 모델(LLM) 추론을 위한 처리량 및 메모리 최적화 엔진입니다.
다음 YAML과 같이 사전 빌드된 JetStream Docker 이미지를 사용하여 일괄 추론 워크로드를 실행할 수 있습니다. 이 컨테이너는 OSS JetStream 프로젝트에서 빌드되었습니다.
MaxText-JetStream 플래그에 관한 자세한 내용은 JetStream MaxText 서버 플래그를 참고하세요.
다음 예에서는 Trillium 칩 (v6e-16)을 사용하여 Llama3.1-405b int8 체크포인트를 로드하고 추론을 실행합니다. 이 예시에서는 이미 GKE 클러스터 내에 하나 이상의 v6e-16 노드 풀이 있다고 가정합니다.
모델 서버 및 Pathways 시작
- 클러스터의 사용자 인증 정보를 가져와 로컬 kubectl 컨텍스트에 추가합니다.
gcloud container clusters get-credentials $CLUSTER \ --zone=$ZONE \ --project=$PROJECT \ && kubectl config set-context --current --namespace=default
- LeaderWorkerSet (LWS) API를 배포합니다.
VERSION=v0.4.0 kubectl apply --server-side -f "https://github.com/kubernetes-sigs/lws/releases/download/${VERSION}/manifests.yaml"
- 다음 YAML을
pathways-job.yaml이라는 파일에 복사하여 붙여넣습니다. 이 YAML은v6e-16슬라이스 모양에 맞게 최적화되었습니다. Meta 체크포인트를 JAX 호환 체크포인트로 변환하는 방법에 관한 자세한 내용은 추론 체크포인트 만들기의 체크포인트 생성 가이드를 참고하세요. 예를 들어 Llama3.1-405B에 관한 안내는 여기 Llama3.1-405B 체크포인트 변환에 제공되어 있습니다. 다음을 바꿉니다.apiVersion: leaderworkerset.x-k8s.io/v1 kind: LeaderWorkerSet metadata: name: jetstream-pathways annotations: leaderworkerset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool spec: replicas: 1 leaderWorkerTemplate: leaderTemplate: metadata: labels: app: jetstream-pathways spec: nodeSelector: cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 4x4 tolerations: - key: "google.com/tpu" operator: "Exists" effect: "NoSchedule" containers: - name: pathways-proxy image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest args: - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677 - --server_port=38681 - --gcs_scratch_location=gs://cloud-pathways-staging/tmp imagePullPolicy: Always ports: - containerPort: 38681 - name: pathways-rm env: - name: HOST_ADDRESS value: "$(LWS_LEADER_ADDRESS)" - name: TPU_SKIP_MDS_QUERY value: "true" image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest args: - --server_port=38677 - --gcs_scratch_location=gs://cloud-pathways-staging/tmp - --node_type=resource_manager - --instance_count=1 - --instance_type=tpuv6e:TPU_TOPOLOGY # Example: 4x4 - --temporary_flags_for_debugging=temporary_flag_for_debugging_worker_expected_tpu_chip_config=megachip_tccontrol imagePullPolicy: Always ports: - containerPort: 38677 - name: jax-tpu image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-pathways:v0.2.0 # Optimized settings used to serve Llama3.1-405b. args: - MaxText/configs/v5e/inference/llama3_405b_v5e-64.yml - model_name=llama3.1-405b - load_parameters_path=GCS_CHECKPOINT_PATH - max_prefill_predict_length=1024 - max_target_length=2048 - async_checkpointing=false - steps=1 - ici_fsdp_parallelism=1 - ici_autoregressive_parallelism=2 - ici_tensor_parallelism=8 - scan_layers=false - weight_dtype=bfloat16 - per_device_batch_size=10 - enable_single_controller=true - quantization=int8 - quantize_kvcache=true - checkpoint_is_quantized=true - enable_model_warmup=true imagePullPolicy: Always ports: - containerPort: 9000 startupProbe: httpGet: path: /healthcheck port: 8000 scheme: HTTP periodSeconds: 1 initialDelaySeconds: 900 failureThreshold: 10000 livenessProbe: httpGet: path: /healthcheck port: 8000 scheme: HTTP periodSeconds: 60 failureThreshold: 10 readinessProbe: httpGet: path: /healthcheck port: 8000 scheme: HTTP periodSeconds: 60 failureThreshold: 10 - name: jetstream-http image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.3 imagePullPolicy: Always ports: - containerPort: 8000 # The size variable defines the number of worker nodes to be created. # It must be equal to the number of hosts + 1 (for the leader node). size: 5 workerTemplate: spec: nodeSelector: cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 4x4 tolerations: - key: "google.com/tpu" operator: "Exists" effect: "NoSchedule" containers: - name: worker args: - --server_port=38679 - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677 - --gcs_scratch_location=gs://cloud-pathways-staging/tmp image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest imagePullPolicy: Always ports: - containerPort: 38679 resources: limits: google.com/tpu: "4" --- apiVersion: v1 kind: Service metadata: name: jetstream-svc spec: selector: app: jetstream-pathways ports: - protocol: TCP name: jetstream-http port: 8000 targetPort: 8000
TPU_ACCELERATOR_TYPE: TPU 가속기 유형입니다. 예를 들면tpu-v6e-slice입니다.TPU_TOPOLOGY: TPU 토폴로지입니다. 예를 들면2x4입니다.GCS_CHECKPOINT_PATH: 체크포인트의 GCS 경로입니다.
- Kubernetes 로그를 확인하여 JetStream 모델 서버가 준비되었는지 확인합니다.
이전 YAML에서 워크로드의 이름은 `jetstream-pathways` 이고 `0`은 헤드 노드입니다.
출력은 다음과 비슷하며, 이는 JetStream 모델 서버가 요청을 처리할 준비가 되었음을 나타냅니다.kubectl logs -f jetstream-pathways-0 -c jax-tpu
2025-03-02 02:15:07,682 - JetstreamLogger - INFO - Initializing the driver with 1 prefill engines and 1 generate engines in interleaved mode 2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up prefill thread 0. 2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up transfer thread 0. 2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up generate thread 0. 2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up detokenize thread 0. 2025-03-02 02:15:07,685 - JetstreamLogger - INFO - Driver initialized. ... ... ... INFO: Started server process [7] INFO: Waiting for application startup. INFO: Application startup complete. INFO: Uvicorn running on http://0.0.0.0:9999 (Press CTRL+C to quit)
모델 서버에 연결
GKE의 ClusterIP 서비스를 사용하여 JetStream Pathways 배포에 액세스할 수 있습니다. ClusterIP 서비스는 클러스터 내에서만 연결할 수 있습니다. 따라서 클러스터 외부에서 서비스에 액세스하려면 먼저 다음 명령어를 실행하여 포트 전달 세션을 설정해야 합니다.
kubectl port-forward pod/${HEAD_POD} 8000:8000
새 터미널을 열고 다음 명령어를 실행하여 JetStream HTTP 서버에 액세스할 수 있는지 확인합니다.
curl --request POST \
--header "Content-type: application/json" \
-s \
localhost:8000/generate \
--data \
'{
"prompt": "What are the top 5 programming languages",
"max_tokens": 200
}'
모델 준비로 인해 초기 요청이 완료되는 데 몇 초 정도 걸릴 수 있습니다. 출력은 다음과 비슷하게 표시됩니다.
{
"response": " for web development?\nThe top 5 programming languages for web development are:\n1. **JavaScript**: JavaScript is the most popular language for web development, used by over 90% of websites for client-side scripting. It's also popular for server-side programming with technologies like Node.js.\n2. **HTML/CSS**: HTML (Hypertext Markup Language) and CSS (Cascading Style Sheets) are not programming languages, but are essential for building websites. HTML is used for structuring content, while CSS is used for styling and layout.\n3. **Python**: Python is a popular language for web development, especially with frameworks like Django and Flask. It's known for its simplicity, flexibility, and large community of developers.\n4. **Java**: Java is a popular language for building enterprise-level web applications, especially with frameworks like Spring and Hibernate. It's known for its platform independence, strong security features, and large community of developers.\n5. **PHP**: PHP is a mature language for web"
}
분할 추론
분할 서빙은 미리 채우기 및 디코딩 단계를 서로 다른 프로세스(잠재적으로 서로 다른 머신)로 분리하여 대규모 언어 모델(LLM)을 실행하는 기법입니다. 이를 통해 리소스를 더 잘 활용할 수 있으며 특히 대규모 모델의 경우 성능과 효율성이 향상될 수 있습니다.
- 사전 입력: 이 단계에서는 입력 프롬프트를 처리하고 중간 표현 (예: 키-값 캐시)을 생성합니다. 컴퓨팅 집약적인 경우가 많습니다.
- 디코딩: 이 단계에서는 사전 입력 표현을 사용하여 출력 토큰을 하나씩 생성합니다. 일반적으로 메모리 대역폭에 제한됩니다.
이러한 단계를 분리함으로써 분할 서빙을 통해 사전 채우기와 디코딩을 병렬로 실행하여 처리량과 지연 시간을 개선할 수 있습니다.
분리된 서빙을 사용 설정하려면 다음 YAML을 수정하여 두 개의 v6e-8 슬라이스를 활용합니다. 하나는 사전 입력용이고 다른 하나는 생성용입니다. 계속하기 전에 GKE 클러스터에 이 v6e-8 토폴로지로 구성된 노드 풀이 두 개 이상 있는지 확인합니다.
최적의 성능을 위해 특정 XLA 플래그가 구성되었습니다.
이전 섹션에 자세히 설명된 llama3.1-405b 체크포인트 생성과 동일한 프로세스를 따라 llama2-70b 체크포인트를 만듭니다.
- 경로를 사용하여 분리된 모드에서 JetStream 서버를 실행하려면 다음 YAML을
pathways-job.yaml이라는 파일에 복사하여 붙여넣습니다. 다음을 바꿉니다.apiVersion: leaderworkerset.x-k8s.io/v1 kind: LeaderWorkerSet metadata: name: jetstream-pathways annotations: leaderworkerset.sigs.k8s.io/subgroup-exclusive-topology: cloud.google.com/gke-nodepool spec: replicas: 1 leaderWorkerTemplate: subGroupPolicy: subGroupSize: 2 leaderTemplate: metadata: labels: app: jetstream-pathways spec: nodeSelector: cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 2x4 tolerations: - key: "google.com/tpu" operator: "Exists" effect: "NoSchedule" containers: - name: pathways-proxy image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest args: - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677 - --server_port=38681 - --gcs_scratch_location=gs://cloud-pathways-staging/tmp imagePullPolicy: Always ports: - containerPort: 38681 - name: pathways-rm env: - name: HOST_ADDRESS value: "$(LWS_LEADER_ADDRESS)" - name: TPU_SKIP_MDS_QUERY value: "true" image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest args: - --server_port=38677 - --gcs_scratch_location=gs://cloud-pathways-staging/tmp - --node_type=resource_manager - --instance_count=2 - --instance_type=tpuv6e:TPU_TOPOLOGY # Example: 2x4 imagePullPolicy: Always ports: - containerPort: 38677 - name: jax-tpu image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-pathways:v0.2.0 # Optimized settings used to serve Llama2-70b. args: - MaxText/configs/base.yml - tokenizer_path=assets/tokenizer.llama2 - load_parameters_path=GCS_CHECKPOINT_PATH - max_prefill_predict_length=1024 - max_target_length=2048 - model_name=llama2-70b - ici_fsdp_parallelism=1 - ici_autoregressive_parallelism=1 - ici_tensor_parallelism=-1 - scan_layers=false - weight_dtype=bfloat16 - per_device_batch_size=1 - checkpoint_is_quantized=true - quantization=int8 - quantize_kvcache=true - compute_axis_order=0,2,1,3 - ar_cache_axis_order=0,2,1,3 - stack_prefill_result_cache=True # Specify disaggregated mode to run Jetstream - inference_server=ExperimentalMaxtextDisaggregatedServer_8 - inference_benchmark_test=True - enable_model_warmup=True env: - name: LOG_LEVEL value: "INFO" imagePullPolicy: Always securityContext: capabilities: add: ["SYS_PTRACE", "NET_ADMIN", "SYS_TIME"] ports: - containerPort: 9000 startupProbe: httpGet: path: /healthcheck port: 8000 scheme: HTTP periodSeconds: 1 initialDelaySeconds: 240 failureThreshold: 10000 livenessProbe: httpGet: path: /healthcheck port: 8000 scheme: HTTP periodSeconds: 60 failureThreshold: 100 readinessProbe: httpGet: path: /healthcheck port: 8000 scheme: HTTP periodSeconds: 60 failureThreshold: 100 - name: jetstream-http image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.3 imagePullPolicy: Always ports: - containerPort: 8000 # The size variable defines the number of worker nodes to be created. # It must be equal to the number of hosts + 1 (for the leader node). size: 5 workerTemplate: spec: nodeSelector: cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 2x4 containers: - name: worker args: - --server_port=38679 - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677 - --gcs_scratch_location=gs://cloud-pathways-staging/tmp image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest imagePullPolicy: Always ports: - containerPort: 38679 resources: limits: google.com/tpu: "4" --- apiVersion: v1 kind: Service metadata: name: jetstream-svc spec: selector: app: jetstream-pathways ports: - protocol: TCP name: jetstream-http port: 8000 targetPort: 8000
TPU_ACCELERATOR_TYPE: TPU 가속기 유형입니다. 예를 들면tpu-v6e-slice입니다.TPU_TOPOLOGY: TPU 토폴로지입니다. 예를 들면2x4입니다.GCS_CHECKPOINT_PATH: 체크포인트의 GCS 경로입니다.
- 이 YAML을 적용하면 모델 서버가 체크포인트를 복원하는 데 시간이 걸립니다. 70B 모델의 경우 약 2분이 걸릴 수 있습니다.
kubectl apply -f pathways-job.yaml
- Kubernetes 로그를 확인하여 JetStream 모델 서버가 준비되었는지 확인합니다.
JetStream 모델 서버가 요청을 처리할 준비가 되었음을 나타내는 다음과 비슷한 출력이 표시됩니다.kubectl logs -f jetstream-pathways-0 -c jax-tpu
2025-03-02 02:15:07,682 - JetstreamLogger - INFO - Initializing the driver with 1 prefill engines and 1 generate engines in interleaved mode 2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up prefill thread 0. 2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up transfer thread 0. 2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up generate thread 0. 2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up detokenize thread 0. 2025-03-02 02:15:07,685 - JetstreamLogger - INFO - Driver initialized. ... ... ... INFO: Started server process [7] INFO: Waiting for application startup. INFO: Application startup complete. INFO: Uvicorn running on http://0.0.0.0:9999 (Press CTRL+C to quit)
모델 서버에 연결
GKE의 ClusterIP 서비스를 통해 JetStream Pathways 배포에 액세스할 수 있습니다. ClusterIP 서비스는 클러스터 내에서만 연결할 수 있습니다. 따라서 클러스터 외부에서 서비스에 액세스하려면 다음 명령어를 실행하여 포트 전달 세션을 설정하세요.
kubectl port-forward pod/${HEAD_POD} 8000:8000
새 터미널을 열고 다음 명령어를 실행하여 JetStream HTTP 서버에 액세스할 수 있는지 확인합니다.
curl --request POST \
--header "Content-type: application/json" \
-s \
localhost:8000/generate \
--data \
'{
"prompt": "What are the top 5 programming languages",
"max_tokens": 200
}'
모델 준비로 인해 초기 요청이 완료되는 데 몇 초 정도 걸릴 수 있습니다. 출력은 다음과 비슷하게 표시됩니다.
{
"response": " used in software development?\nThe top 5 programming languages used in software development are:\n\n1. Java: Java is a popular programming language used for developing enterprise-level applications, Android apps, and web applications. Its platform independence and ability to run on any device that has a Java Virtual Machine (JVM) installed make it a favorite among developers.\n2. Python: Python is a versatile language that is widely used in software development, data analysis, artificial intelligence, and machine learning. Its simplicity, readability, and ease of use make it a popular choice among developers.\n3. JavaScript: JavaScript is a widely used programming language for web development, allowing developers to create interactive client-side functionality for web applications. It is also used for server-side programming, desktop and mobile application development, and game development.\n4. C++: C++ is a high-performance programming language used for developing operating systems, games, and other high-performance applications."
}