경로를 사용하여 멀티 호스트 추론 실행

멀티 호스트 추론은 여러 가속기 호스트에 모델을 분산하는 모델 추론을 실행하는 방법입니다. 이를 통해 단일 호스트에 맞지 않는 대규모 모델의 추론이 가능합니다. 경로는 일괄 및 실시간 멀티 호스트 추론 사용 사례에 모두 배포할 수 있습니다.

시작하기 전에

다음 사항이 필요합니다.

JetStream을 사용하여 일괄 추론 실행

JetStream은 JAX로 작성된 XLA 기기(주로 텐서 처리 장치(TPU))에서 대규모 언어 모델(LLM) 추론을 위한 처리량 및 메모리 최적화 엔진입니다.

다음 YAML과 같이 사전 빌드된 JetStream Docker 이미지를 사용하여 일괄 추론 워크로드를 실행할 수 있습니다. 이 컨테이너는 OSS JetStream 프로젝트에서 빌드되었습니다. MaxText-JetStream 플래그에 관한 자세한 내용은 JetStream MaxText 서버 플래그를 참고하세요. 다음 예에서는 Trillium 칩 (v6e-16)을 사용하여 Llama3.1-405b int8 체크포인트를 로드하고 추론을 실행합니다. 이 예시에서는 이미 GKE 클러스터 내에 하나 이상의 v6e-16 노드 풀이 있다고 가정합니다.

모델 서버 및 Pathways 시작

  1. 클러스터의 사용자 인증 정보를 가져와 로컬 kubectl 컨텍스트에 추가합니다.
          gcloud container clusters get-credentials $CLUSTER \
          --zone=$ZONE \
          --project=$PROJECT \
          && kubectl config set-context --current --namespace=default
        
  2. LeaderWorkerSet (LWS) API를 배포합니다.
          VERSION=v0.4.0
          kubectl apply --server-side -f "https://github.com/kubernetes-sigs/lws/releases/download/${VERSION}/manifests.yaml"
        
  3. 다음 YAML을 pathways-job.yaml이라는 파일에 복사하여 붙여넣습니다. 이 YAML은 v6e-16 슬라이스 모양에 맞게 최적화되었습니다. Meta 체크포인트를 JAX 호환 체크포인트로 변환하는 방법에 관한 자세한 내용은 추론 체크포인트 만들기의 체크포인트 생성 가이드를 참고하세요. 예를 들어 Llama3.1-405B에 관한 안내는 여기 Llama3.1-405B 체크포인트 변환에 제공되어 있습니다.
        apiVersion: leaderworkerset.x-k8s.io/v1
        kind: LeaderWorkerSet
        metadata:
          name: jetstream-pathways
          annotations:
            leaderworkerset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
        spec:
          replicas: 1
          leaderWorkerTemplate:
            leaderTemplate:
              metadata:
                labels: 
                  app: jetstream-pathways
              spec:
                nodeSelector:
                  cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice
                  cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 4x4
                tolerations:
                - key: "google.com/tpu"
                  operator: "Exists"
                  effect: "NoSchedule"
                containers:
                - name: pathways-proxy
                  image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest
                  args:
                  - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677
                  - --server_port=38681
                  - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
                  imagePullPolicy: Always
                  ports:
                  - containerPort: 38681
                - name: pathways-rm
                  env:       
                  - name: HOST_ADDRESS
                    value: "$(LWS_LEADER_ADDRESS)"
                  - name: TPU_SKIP_MDS_QUERY
                    value: "true"
                  image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest
                  args:
                  - --server_port=38677
                  - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
                  - --node_type=resource_manager
                  - --instance_count=1
                  - --instance_type=tpuv6e:TPU_TOPOLOGY # Example: 4x4
                  - --temporary_flags_for_debugging=temporary_flag_for_debugging_worker_expected_tpu_chip_config=megachip_tccontrol
                  imagePullPolicy: Always
                  ports:
                  - containerPort: 38677
                - name: jax-tpu
                  image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-pathways:v0.2.0
                  # Optimized settings used to serve Llama3.1-405b.
                  args:
                  - MaxText/configs/v5e/inference/llama3_405b_v5e-64.yml
                  - model_name=llama3.1-405b
                  - load_parameters_path=GCS_CHECKPOINT_PATH
                  - max_prefill_predict_length=1024
                  - max_target_length=2048
                  - async_checkpointing=false
                  - steps=1
                  - ici_fsdp_parallelism=1
                  - ici_autoregressive_parallelism=2
                  - ici_tensor_parallelism=8
                  - scan_layers=false
                  - weight_dtype=bfloat16
                  - per_device_batch_size=10
                  - enable_single_controller=true
                  - quantization=int8
                  - quantize_kvcache=true
                  - checkpoint_is_quantized=true
                  - enable_model_warmup=true
                  imagePullPolicy: Always
                  ports: 
                  - containerPort: 9000
                  startupProbe:
                    httpGet:
                      path: /healthcheck
                      port: 8000
                      scheme: HTTP
                    periodSeconds: 1
                    initialDelaySeconds: 900
                    failureThreshold: 10000
                  livenessProbe:
                    httpGet:
                      path: /healthcheck
                      port: 8000
                      scheme: HTTP
                    periodSeconds: 60
                    failureThreshold: 10
                  readinessProbe:
                    httpGet:
                      path: /healthcheck
                      port: 8000
                      scheme: HTTP
                    periodSeconds: 60
                    failureThreshold: 10
                - name: jetstream-http
                  image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.3
                  imagePullPolicy: Always
                  ports:
                  - containerPort: 8000
            # The size variable defines the number of worker nodes to be created.
            # It must be equal to the number of hosts + 1 (for the leader node).
            size: 5
            workerTemplate:
              spec:
                nodeSelector:
                  cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice
                  cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 4x4
                tolerations:
                - key: "google.com/tpu"
                  operator: "Exists"
                  effect: "NoSchedule"
                containers:
                - name: worker
                  args:
                  - --server_port=38679
                  - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677
                  - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
                  image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest
                  imagePullPolicy: Always
                  ports:
                  - containerPort: 38679
                  resources:
                    limits:
                      google.com/tpu: "4"
        --- 
        apiVersion: v1
        kind: Service
        metadata:
          name: jetstream-svc
        spec:
          selector:
            app: jetstream-pathways
          ports:
          - protocol: TCP
            name: jetstream-http
            port: 8000
            targetPort: 8000
        
    다음을 바꿉니다.
    • TPU_ACCELERATOR_TYPE: TPU 가속기 유형입니다. 예를 들면 tpu-v6e-slice입니다.
    • TPU_TOPOLOGY: TPU 토폴로지입니다. 예를 들면 2x4입니다.
    • GCS_CHECKPOINT_PATH: 체크포인트의 GCS 경로입니다.
    이 YAML을 적용합니다. PathwaysJob이 예약될 때까지 기다립니다. 예약이 완료되면 모델 서버가 체크포인트를 복원하는 데 시간이 걸릴 수 있습니다. 405B 모델의 경우 약 7분이 걸립니다.
  4. Kubernetes 로그를 확인하여 JetStream 모델 서버가 준비되었는지 확인합니다. 이전 YAML에서 워크로드의 이름은 `jetstream-pathways` 이고 `0`은 헤드 노드입니다.
          kubectl logs -f jetstream-pathways-0 -c jax-tpu
          
    출력은 다음과 비슷하며, 이는 JetStream 모델 서버가 요청을 처리할 준비가 되었음을 나타냅니다.
        2025-03-02 02:15:07,682 - JetstreamLogger - INFO - Initializing the driver with 1 prefill engines and 1 generate engines in interleaved mode
        2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up prefill thread 0.
        2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up transfer thread 0.
        2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up generate thread 0.
        2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up detokenize thread 0.
        2025-03-02 02:15:07,685 - JetstreamLogger - INFO - Driver initialized.
        ...
        ...
        ...
        INFO:     Started server process [7]
        INFO:     Waiting for application startup.
        INFO:     Application startup complete.
        INFO:     Uvicorn running on http://0.0.0.0:9999 (Press CTRL+C to quit)
        

모델 서버에 연결

GKE의 ClusterIP 서비스를 사용하여 JetStream Pathways 배포에 액세스할 수 있습니다. ClusterIP 서비스는 클러스터 내에서만 연결할 수 있습니다. 따라서 클러스터 외부에서 서비스에 액세스하려면 먼저 다음 명령어를 실행하여 포트 전달 세션을 설정해야 합니다.

kubectl port-forward pod/${HEAD_POD} 8000:8000

새 터미널을 열고 다음 명령어를 실행하여 JetStream HTTP 서버에 액세스할 수 있는지 확인합니다.

curl --request POST \
--header "Content-type: application/json" \
-s \
localhost:8000/generate \
--data \
'{
    "prompt": "What are the top 5 programming languages",
    "max_tokens": 200
}'

모델 준비로 인해 초기 요청이 완료되는 데 몇 초 정도 걸릴 수 있습니다. 출력은 다음과 비슷하게 표시됩니다.

{
    "response": " for web development?\nThe top 5 programming languages for web development are:\n1. **JavaScript**: JavaScript is the most popular language for web development, used by over 90% of websites for client-side scripting. It's also popular for server-side programming with technologies like Node.js.\n2. **HTML/CSS**: HTML (Hypertext Markup Language) and CSS (Cascading Style Sheets) are not programming languages, but are essential for building websites. HTML is used for structuring content, while CSS is used for styling and layout.\n3. **Python**: Python is a popular language for web development, especially with frameworks like Django and Flask. It's known for its simplicity, flexibility, and large community of developers.\n4. **Java**: Java is a popular language for building enterprise-level web applications, especially with frameworks like Spring and Hibernate. It's known for its platform independence, strong security features, and large community of developers.\n5. **PHP**: PHP is a mature language for web"
}

분할 추론

분할 서빙미리 채우기디코딩 단계를 서로 다른 프로세스(잠재적으로 서로 다른 머신)로 분리하여 대규모 언어 모델(LLM)을 실행하는 기법입니다. 이를 통해 리소스를 더 잘 활용할 수 있으며 특히 대규모 모델의 경우 성능과 효율성이 향상될 수 있습니다.

  • 사전 입력: 이 단계에서는 입력 프롬프트를 처리하고 중간 표현 (예: 키-값 캐시)을 생성합니다. 컴퓨팅 집약적인 경우가 많습니다.
  • 디코딩: 이 단계에서는 사전 입력 표현을 사용하여 출력 토큰을 하나씩 생성합니다. 일반적으로 메모리 대역폭에 제한됩니다.

이러한 단계를 분리함으로써 분할 서빙을 통해 사전 채우기와 디코딩을 병렬로 실행하여 처리량과 지연 시간을 개선할 수 있습니다.

분리된 서빙을 사용 설정하려면 다음 YAML을 수정하여 두 개의 v6e-8 슬라이스를 활용합니다. 하나는 사전 입력용이고 다른 하나는 생성용입니다. 계속하기 전에 GKE 클러스터에 이 v6e-8 토폴로지로 구성된 노드 풀이 두 개 이상 있는지 확인합니다. 최적의 성능을 위해 특정 XLA 플래그가 구성되었습니다.

이전 섹션에 자세히 설명된 llama3.1-405b 체크포인트 생성과 동일한 프로세스를 따라 llama2-70b 체크포인트를 만듭니다.

  1. 경로를 사용하여 분리된 모드에서 JetStream 서버를 실행하려면 다음 YAML을 pathways-job.yaml이라는 파일에 복사하여 붙여넣습니다.
    apiVersion: leaderworkerset.x-k8s.io/v1
    kind: LeaderWorkerSet
    metadata:
      name: jetstream-pathways
      annotations:
        leaderworkerset.sigs.k8s.io/subgroup-exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      replicas: 1
      leaderWorkerTemplate:
        subGroupPolicy:
          subGroupSize: 2
        leaderTemplate:
          metadata:
            labels:
              app: jetstream-pathways
          spec:
            nodeSelector:
              cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice
              cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 2x4
            tolerations:
            - key: "google.com/tpu"
              operator: "Exists"
              effect: "NoSchedule"
            containers:
            - name: pathways-proxy
              image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest
              args:
              - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677
              - --server_port=38681
              - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
              imagePullPolicy: Always
              ports:
              - containerPort: 38681
            - name: pathways-rm
              env:       
              - name: HOST_ADDRESS
                value: "$(LWS_LEADER_ADDRESS)"
              - name: TPU_SKIP_MDS_QUERY
                value: "true"
              image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest
              args:
              - --server_port=38677
              - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
              - --node_type=resource_manager
              - --instance_count=2
              - --instance_type=tpuv6e:TPU_TOPOLOGY # Example: 2x4
              imagePullPolicy: Always
              ports:
              - containerPort: 38677
            - name: jax-tpu
              image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-pathways:v0.2.0
              # Optimized settings used to serve Llama2-70b.
              args:
              - MaxText/configs/base.yml
              - tokenizer_path=assets/tokenizer.llama2
              - load_parameters_path=GCS_CHECKPOINT_PATH
              - max_prefill_predict_length=1024
              - max_target_length=2048
              - model_name=llama2-70b
              - ici_fsdp_parallelism=1
              - ici_autoregressive_parallelism=1
              - ici_tensor_parallelism=-1
              - scan_layers=false
              - weight_dtype=bfloat16
              - per_device_batch_size=1
              - checkpoint_is_quantized=true 
              - quantization=int8
              - quantize_kvcache=true
              - compute_axis_order=0,2,1,3
              - ar_cache_axis_order=0,2,1,3
              - stack_prefill_result_cache=True
              # Specify disaggregated mode to run Jetstream
              - inference_server=ExperimentalMaxtextDisaggregatedServer_8
              - inference_benchmark_test=True
              - enable_model_warmup=True
              env:
              - name: LOG_LEVEL
                value: "INFO"
              imagePullPolicy: Always
              securityContext:
                capabilities:
                  add: ["SYS_PTRACE", "NET_ADMIN", "SYS_TIME"]
              ports: 
              - containerPort: 9000
              startupProbe:
                httpGet:
                  path: /healthcheck
                  port: 8000
                  scheme: HTTP
                periodSeconds: 1
                initialDelaySeconds: 240
                failureThreshold: 10000
              livenessProbe:
                httpGet:
                  path: /healthcheck
                  port: 8000
                  scheme: HTTP
                periodSeconds: 60
                failureThreshold: 100
              readinessProbe:
                httpGet:
                  path: /healthcheck
                  port: 8000
                  scheme: HTTP
                periodSeconds: 60
                failureThreshold: 100
            - name: jetstream-http
              image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.3
              imagePullPolicy: Always
              ports:
              - containerPort: 8000
        # The size variable defines the number of worker nodes to be created.
        # It must be equal to the number of hosts + 1 (for the leader node).
        size: 5
        workerTemplate:
          spec:
            nodeSelector:
              cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice
              cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 2x4
            containers:
            - name: worker
              args:
              - --server_port=38679
              - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677
              - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
              image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest
              imagePullPolicy: Always
              ports:
              - containerPort: 38679
              resources:
                limits:
                  google.com/tpu: "4"
    --- 
    apiVersion: v1
    kind: Service
    metadata:
      name: jetstream-svc
    spec:
      selector:
        app: jetstream-pathways
      ports:
      - protocol: TCP
        name: jetstream-http
        port: 8000
        targetPort: 8000
      
    다음을 바꿉니다.
    • TPU_ACCELERATOR_TYPE: TPU 가속기 유형입니다. 예를 들면 tpu-v6e-slice입니다.
    • TPU_TOPOLOGY: TPU 토폴로지입니다. 예를 들면 2x4입니다.
    • GCS_CHECKPOINT_PATH: 체크포인트의 GCS 경로입니다.
  2. 이 YAML을 적용하면 모델 서버가 체크포인트를 복원하는 데 시간이 걸립니다. 70B 모델의 경우 약 2분이 걸릴 수 있습니다.
      kubectl apply -f pathways-job.yaml
          
  3. Kubernetes 로그를 확인하여 JetStream 모델 서버가 준비되었는지 확인합니다.
        kubectl logs -f jetstream-pathways-0 -c jax-tpu
        
    JetStream 모델 서버가 요청을 처리할 준비가 되었음을 나타내는 다음과 비슷한 출력이 표시됩니다.
        2025-03-02 02:15:07,682 - JetstreamLogger - INFO - Initializing the driver with 1 prefill engines and 1 generate engines in interleaved mode
        2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up prefill thread 0.
        2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up transfer thread 0.
        2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up generate thread 0.
        2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up detokenize thread 0.
        2025-03-02 02:15:07,685 - JetstreamLogger - INFO - Driver initialized.
        ...
        ...
        ...
        INFO:     Started server process [7]
        INFO:     Waiting for application startup.
        INFO:     Application startup complete.
        INFO:     Uvicorn running on http://0.0.0.0:9999 (Press CTRL+C to quit)
      

모델 서버에 연결

GKE의 ClusterIP 서비스를 통해 JetStream Pathways 배포에 액세스할 수 있습니다. ClusterIP 서비스는 클러스터 내에서만 연결할 수 있습니다. 따라서 클러스터 외부에서 서비스에 액세스하려면 다음 명령어를 실행하여 포트 전달 세션을 설정하세요.

kubectl port-forward pod/${HEAD_POD} 8000:8000

새 터미널을 열고 다음 명령어를 실행하여 JetStream HTTP 서버에 액세스할 수 있는지 확인합니다.

curl --request POST \
--header "Content-type: application/json" \
-s \
localhost:8000/generate \
--data \
'{
    "prompt": "What are the top 5 programming languages",
    "max_tokens": 200
}'

모델 준비로 인해 초기 요청이 완료되는 데 몇 초 정도 걸릴 수 있습니다. 출력은 다음과 비슷하게 표시됩니다.

{
    "response": " used in software development?\nThe top 5 programming languages used in software development are:\n\n1. Java: Java is a popular programming language used for developing enterprise-level applications, Android apps, and web applications. Its platform independence and ability to run on any device that has a Java Virtual Machine (JVM) installed make it a favorite among developers.\n2. Python: Python is a versatile language that is widely used in software development, data analysis, artificial intelligence, and machine learning. Its simplicity, readability, and ease of use make it a popular choice among developers.\n3. JavaScript: JavaScript is a widely used programming language for web development, allowing developers to create interactive client-side functionality for web applications. It is also used for server-side programming, desktop and mobile application development, and game development.\n4. C++: C++ is a high-performance programming language used for developing operating systems, games, and other high-performance applications."
}

다음 단계