Pathways を使用してマルチホスト推論を実行する

マルチホスト推論は、モデルを複数のアクセラレータ ホストに分散してモデル推論を実行する方法です。これにより、単一のホストに収まらない大規模なモデルの推論が可能になります。Pathways は、バッチとリアルタイムの両方のマルチホスト推論ユースケースにデプロイできます。

始める前に

インストールに必要なもの:

JetStream を使用してバッチ推論を実行する

JetStream は、XLA デバイス(主に JAX で記述された Tensor Processing Unit(TPU))での大規模言語モデル(LLM)推論用にスループットとメモリが最適化されたエンジンです。

次の YAML に示すように、ビルド済みの JetStream Docker イメージを使用して、バッチ推論ワークロードを実行できます。このコンテナは OSS JetStream プロジェクトからビルドされています。MaxText-JetStream フラグの詳細については、JetStream MaxText サーバーフラグをご覧ください。次の例では、Trillium チップ(v6e-16)を使用して Llama3.1-405b int8 チェックポイントを読み込み、推論を実行します。この例では、少なくとも 1 つの v6e-16 ノードプールを含む GKE クラスタがすでに存在することを前提としています。

モデルサーバーと Pathways を起動する

  1. クラスタの認証情報を取得し、ローカルの kubectl コンテキストに追加します。
          gcloud container clusters get-credentials $CLUSTER \
          --zone=$ZONE \
          --project=$PROJECT \
          && kubectl config set-context --current --namespace=default
        
  2. LeaderWorkerSet(LWS)API をデプロイします。
          VERSION=v0.4.0
          kubectl apply --server-side -f "https://github.com/kubernetes-sigs/lws/releases/download/${VERSION}/manifests.yaml"
        
  3. 次の YAML をコピーして pathways-job.yaml というファイルに貼り付けます。この YAML は v6e-16 スライス形状用に最適化されています。Meta チェックポイントを JAX 互換のチェックポイントに変換する方法について詳しくは、推論チェックポイントを作成するのチェックポイント作成ガイドをご覧ください。たとえば、Llama3.1-405B の手順については、Llama3.1-405B のチェックポイント変換をご覧ください。
        apiVersion: leaderworkerset.x-k8s.io/v1
        kind: LeaderWorkerSet
        metadata:
          name: jetstream-pathways
          annotations:
            leaderworkerset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
        spec:
          replicas: 1
          leaderWorkerTemplate:
            leaderTemplate:
              metadata:
                labels: 
                  app: jetstream-pathways
              spec:
                nodeSelector:
                  cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice
                  cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 4x4
                tolerations:
                - key: "google.com/tpu"
                  operator: "Exists"
                  effect: "NoSchedule"
                containers:
                - name: pathways-proxy
                  image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest
                  args:
                  - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677
                  - --server_port=38681
                  - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
                  imagePullPolicy: Always
                  ports:
                  - containerPort: 38681
                - name: pathways-rm
                  env:       
                  - name: HOST_ADDRESS
                    value: "$(LWS_LEADER_ADDRESS)"
                  - name: TPU_SKIP_MDS_QUERY
                    value: "true"
                  image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest
                  args:
                  - --server_port=38677
                  - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
                  - --node_type=resource_manager
                  - --instance_count=1
                  - --instance_type=tpuv6e:TPU_TOPOLOGY # Example: 4x4
                  - --temporary_flags_for_debugging=temporary_flag_for_debugging_worker_expected_tpu_chip_config=megachip_tccontrol
                  imagePullPolicy: Always
                  ports:
                  - containerPort: 38677
                - name: jax-tpu
                  image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-pathways:v0.2.0
                  # Optimized settings used to serve Llama3.1-405b.
                  args:
                  - MaxText/configs/v5e/inference/llama3_405b_v5e-64.yml
                  - model_name=llama3.1-405b
                  - load_parameters_path=GCS_CHECKPOINT_PATH
                  - max_prefill_predict_length=1024
                  - max_target_length=2048
                  - async_checkpointing=false
                  - steps=1
                  - ici_fsdp_parallelism=1
                  - ici_autoregressive_parallelism=2
                  - ici_tensor_parallelism=8
                  - scan_layers=false
                  - weight_dtype=bfloat16
                  - per_device_batch_size=10
                  - enable_single_controller=true
                  - quantization=int8
                  - quantize_kvcache=true
                  - checkpoint_is_quantized=true
                  - enable_model_warmup=true
                  imagePullPolicy: Always
                  ports: 
                  - containerPort: 9000
                  startupProbe:
                    httpGet:
                      path: /healthcheck
                      port: 8000
                      scheme: HTTP
                    periodSeconds: 1
                    initialDelaySeconds: 900
                    failureThreshold: 10000
                  livenessProbe:
                    httpGet:
                      path: /healthcheck
                      port: 8000
                      scheme: HTTP
                    periodSeconds: 60
                    failureThreshold: 10
                  readinessProbe:
                    httpGet:
                      path: /healthcheck
                      port: 8000
                      scheme: HTTP
                    periodSeconds: 60
                    failureThreshold: 10
                - name: jetstream-http
                  image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.3
                  imagePullPolicy: Always
                  ports:
                  - containerPort: 8000
            # The size variable defines the number of worker nodes to be created.
            # It must be equal to the number of hosts + 1 (for the leader node).
            size: 5
            workerTemplate:
              spec:
                nodeSelector:
                  cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice
                  cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 4x4
                tolerations:
                - key: "google.com/tpu"
                  operator: "Exists"
                  effect: "NoSchedule"
                containers:
                - name: worker
                  args:
                  - --server_port=38679
                  - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677
                  - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
                  image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest
                  imagePullPolicy: Always
                  ports:
                  - containerPort: 38679
                  resources:
                    limits:
                      google.com/tpu: "4"
        --- 
        apiVersion: v1
        kind: Service
        metadata:
          name: jetstream-svc
        spec:
          selector:
            app: jetstream-pathways
          ports:
          - protocol: TCP
            name: jetstream-http
            port: 8000
            targetPort: 8000
        
    次のように置き換えます。
    • TPU_ACCELERATOR_TYPE: TPU アクセラレータ タイプ。例: tpu-v6e-slice
    • TPU_TOPOLOGY: TPU トポロジ。例: 2x4
    • GCS_CHECKPOINT_PATH: チェックポイントの GCS パス。
    この YAML を適用します。PathwaysJob がスケジュールされるまで待ちます。スケジュールされた後、モデルサーバーがチェックポイントを復元するまでに時間がかかることがあります。405B モデルの場合、これには約 7 分かかります。
  4. Kubernetes ログを確認して、JetStream モデルサーバーの準備ができているかどうかを確認します。前の YAML では、ワークロードの名前は `jetstream-pathways` で、`0` はヘッドノードです。
          kubectl logs -f jetstream-pathways-0 -c jax-tpu
          
    出力は次のようになります。これは、JetStream モデルサーバーがリクエストを処理する準備ができていることを示しています。
        2025-03-02 02:15:07,682 - JetstreamLogger - INFO - Initializing the driver with 1 prefill engines and 1 generate engines in interleaved mode
        2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up prefill thread 0.
        2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up transfer thread 0.
        2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up generate thread 0.
        2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up detokenize thread 0.
        2025-03-02 02:15:07,685 - JetstreamLogger - INFO - Driver initialized.
        ...
        ...
        ...
        INFO:     Started server process [7]
        INFO:     Waiting for application startup.
        INFO:     Application startup complete.
        INFO:     Uvicorn running on http://0.0.0.0:9999 (Press CTRL+C to quit)
        

モデルサーバーに接続する

JetStream Pathways Deployment には、GKE の ClusterIP Service を使用してアクセスできます。ClusterIP Service にはクラスタ内からのみアクセスできます。したがって、クラスタの外部からサービスにアクセスするには、まず次のコマンドを実行してポート転送セッションを確立する必要があります。

kubectl port-forward pod/${HEAD_POD} 8000:8000

新しいターミナルを開いて次のコマンドを実行し、JetStream HTTP サーバーにアクセスできることを確認します。

curl --request POST \
--header "Content-type: application/json" \
-s \
localhost:8000/generate \
--data \
'{
    "prompt": "What are the top 5 programming languages",
    "max_tokens": 200
}'

モデルのウォームアップにより、最初のリクエストが完了するまでに数秒かかることがあります。出力例を以下に示します。

{
    "response": " for web development?\nThe top 5 programming languages for web development are:\n1. **JavaScript**: JavaScript is the most popular language for web development, used by over 90% of websites for client-side scripting. It's also popular for server-side programming with technologies like Node.js.\n2. **HTML/CSS**: HTML (Hypertext Markup Language) and CSS (Cascading Style Sheets) are not programming languages, but are essential for building websites. HTML is used for structuring content, while CSS is used for styling and layout.\n3. **Python**: Python is a popular language for web development, especially with frameworks like Django and Flask. It's known for its simplicity, flexibility, and large community of developers.\n4. **Java**: Java is a popular language for building enterprise-level web applications, especially with frameworks like Spring and Hibernate. It's known for its platform independence, strong security features, and large community of developers.\n5. **PHP**: PHP is a mature language for web"
}

分離型推論

分離型サービングは、大規模言語モデル(LLM)を実行する手法です。この手法では、プレフィル ステージとデコード ステージを異なるプロセスに分離します。これらのプロセスは、異なるマシンで実行される可能性があります。その結果、リソースの利用率が向上し、特に大規模なモデルのパフォーマンスと効率が改善されます。

  • プリフィル: このステージでは、入力プロンプトを処理し、中間表現(Key-Value キャッシュなど)を生成します。多くの場合、コンピューティング負荷が高くなります。
  • デコード: このステージでは、プリフィル表現を使用して出力トークンを 1 つずつ生成します。通常はメモリ帯域幅に制限されます。

これらのステージを分離することで、分離型サービングではプレフィルとデコードを並行して実行できるため、スループットとレイテンシが向上します。

分離型サービングを有効にするには、次の YAML を変更して、2 つの v6e-8 スライス(1 つはプリフィル用、もう 1 つは生成用)を利用します。続行する前に、GKE クラスタにこの v6e-8 トポロジで構成されたノードプールが少なくとも 2 つあることを確認してください。最適なパフォーマンスを実現するために、特定の XLA フラグが構成されています。

前のセクションで説明した llama3.1-405b チェックポイントの作成と同じプロセスで、llama2-70b チェックポイントを作成します。

  1. Pathways を使用して JetStream サーバーを分離モードで起動するには、次の YAML をコピーして pathways-job.yaml という名前のファイルに貼り付けます。
    apiVersion: leaderworkerset.x-k8s.io/v1
    kind: LeaderWorkerSet
    metadata:
      name: jetstream-pathways
      annotations:
        leaderworkerset.sigs.k8s.io/subgroup-exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      replicas: 1
      leaderWorkerTemplate:
        subGroupPolicy:
          subGroupSize: 2
        leaderTemplate:
          metadata:
            labels:
              app: jetstream-pathways
          spec:
            nodeSelector:
              cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice
              cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 2x4
            tolerations:
            - key: "google.com/tpu"
              operator: "Exists"
              effect: "NoSchedule"
            containers:
            - name: pathways-proxy
              image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest
              args:
              - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677
              - --server_port=38681
              - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
              imagePullPolicy: Always
              ports:
              - containerPort: 38681
            - name: pathways-rm
              env:       
              - name: HOST_ADDRESS
                value: "$(LWS_LEADER_ADDRESS)"
              - name: TPU_SKIP_MDS_QUERY
                value: "true"
              image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest
              args:
              - --server_port=38677
              - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
              - --node_type=resource_manager
              - --instance_count=2
              - --instance_type=tpuv6e:TPU_TOPOLOGY # Example: 2x4
              imagePullPolicy: Always
              ports:
              - containerPort: 38677
            - name: jax-tpu
              image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-pathways:v0.2.0
              # Optimized settings used to serve Llama2-70b.
              args:
              - MaxText/configs/base.yml
              - tokenizer_path=assets/tokenizer.llama2
              - load_parameters_path=GCS_CHECKPOINT_PATH
              - max_prefill_predict_length=1024
              - max_target_length=2048
              - model_name=llama2-70b
              - ici_fsdp_parallelism=1
              - ici_autoregressive_parallelism=1
              - ici_tensor_parallelism=-1
              - scan_layers=false
              - weight_dtype=bfloat16
              - per_device_batch_size=1
              - checkpoint_is_quantized=true 
              - quantization=int8
              - quantize_kvcache=true
              - compute_axis_order=0,2,1,3
              - ar_cache_axis_order=0,2,1,3
              - stack_prefill_result_cache=True
              # Specify disaggregated mode to run Jetstream
              - inference_server=ExperimentalMaxtextDisaggregatedServer_8
              - inference_benchmark_test=True
              - enable_model_warmup=True
              env:
              - name: LOG_LEVEL
                value: "INFO"
              imagePullPolicy: Always
              securityContext:
                capabilities:
                  add: ["SYS_PTRACE", "NET_ADMIN", "SYS_TIME"]
              ports: 
              - containerPort: 9000
              startupProbe:
                httpGet:
                  path: /healthcheck
                  port: 8000
                  scheme: HTTP
                periodSeconds: 1
                initialDelaySeconds: 240
                failureThreshold: 10000
              livenessProbe:
                httpGet:
                  path: /healthcheck
                  port: 8000
                  scheme: HTTP
                periodSeconds: 60
                failureThreshold: 100
              readinessProbe:
                httpGet:
                  path: /healthcheck
                  port: 8000
                  scheme: HTTP
                periodSeconds: 60
                failureThreshold: 100
            - name: jetstream-http
              image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.3
              imagePullPolicy: Always
              ports:
              - containerPort: 8000
        # The size variable defines the number of worker nodes to be created.
        # It must be equal to the number of hosts + 1 (for the leader node).
        size: 5
        workerTemplate:
          spec:
            nodeSelector:
              cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice
              cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 2x4
            containers:
            - name: worker
              args:
              - --server_port=38679
              - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677
              - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
              image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest
              imagePullPolicy: Always
              ports:
              - containerPort: 38679
              resources:
                limits:
                  google.com/tpu: "4"
    --- 
    apiVersion: v1
    kind: Service
    metadata:
      name: jetstream-svc
    spec:
      selector:
        app: jetstream-pathways
      ports:
      - protocol: TCP
        name: jetstream-http
        port: 8000
        targetPort: 8000
      
    次のように置き換えます。
    • TPU_ACCELERATOR_TYPE: TPU アクセラレータ タイプ。例: tpu-v6e-slice
    • TPU_TOPOLOGY: TPU トポロジ。例: 2x4
    • GCS_CHECKPOINT_PATH: チェックポイントの GCS パス。
  2. この YAML を適用すると、モデルサーバーがチェックポイントを復元するまでに時間がかかります。70B モデルの場合、これには約 2 分かかることがあります。
      kubectl apply -f pathways-job.yaml
          
  3. Kubernetes ログを確認して、JetStream モデルサーバーの準備が整っているかどうかを確認します。
        kubectl logs -f jetstream-pathways-0 -c jax-tpu
        
    次のような出力が表示されます。これは、JetStream モデルサーバーがリクエストを処理する準備ができていることを示します。
        2025-03-02 02:15:07,682 - JetstreamLogger - INFO - Initializing the driver with 1 prefill engines and 1 generate engines in interleaved mode
        2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up prefill thread 0.
        2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up transfer thread 0.
        2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up generate thread 0.
        2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up detokenize thread 0.
        2025-03-02 02:15:07,685 - JetstreamLogger - INFO - Driver initialized.
        ...
        ...
        ...
        INFO:     Started server process [7]
        INFO:     Waiting for application startup.
        INFO:     Application startup complete.
        INFO:     Uvicorn running on http://0.0.0.0:9999 (Press CTRL+C to quit)
      

モデルサーバーに接続する

JetStream Pathways Deployment には、GKE の ClusterIP サービスを介してアクセスできます。ClusterIP Service にはクラスタ内からのみアクセスできます。したがって、クラスタの外部から Service にアクセスするには、次のコマンドを実行してポート転送セッションを確立します。

kubectl port-forward pod/${HEAD_POD} 8000:8000

新しいターミナルを開いて次のコマンドを実行し、JetStream HTTP サーバーにアクセスできることを確認します。

curl --request POST \
--header "Content-type: application/json" \
-s \
localhost:8000/generate \
--data \
'{
    "prompt": "What are the top 5 programming languages",
    "max_tokens": 200
}'

モデルのウォームアップにより、最初のリクエストが完了するまでに数秒かかることがあります。出力例を以下に示します。

{
    "response": " used in software development?\nThe top 5 programming languages used in software development are:\n\n1. Java: Java is a popular programming language used for developing enterprise-level applications, Android apps, and web applications. Its platform independence and ability to run on any device that has a Java Virtual Machine (JVM) installed make it a favorite among developers.\n2. Python: Python is a versatile language that is widely used in software development, data analysis, artificial intelligence, and machine learning. Its simplicity, readability, and ease of use make it a popular choice among developers.\n3. JavaScript: JavaScript is a widely used programming language for web development, allowing developers to create interactive client-side functionality for web applications. It is also used for server-side programming, desktop and mobile application development, and game development.\n4. C++: C++ is a high-performance programming language used for developing operating systems, games, and other high-performance applications."
}

次のステップ