Inferenz auf mehreren Hosts mit Pathways durchführen

Die Inferenz mit mehreren Hosts ist eine Methode zum Ausführen der Modellinferenz, bei der das Modell auf mehrere Accelerator-Hosts verteilt wird. So können große Modelle, die nicht auf einen einzelnen Host passen, abgeleitet werden. Pathways können sowohl für Batch- als auch für Echtzeit-Multihost-Inferenzanwendungsfälle bereitgestellt werden.

Hinweise

Sie benötigen Folgendes:

Batchinferenz mit JetStream ausführen

JetStream ist eine auf Durchsatz und Arbeitsspeicher optimierte Engine für die Inferenz großer Sprachmodelle (LLM) auf XLA-Geräten, hauptsächlich Tensor Processing Units (TPUs), die in JAX geschrieben sind.

Sie können ein vorgefertigtes JetStream-Docker-Image verwenden, um eine Batchinferenz-Arbeitslast auszuführen, wie im folgenden YAML-Code gezeigt. Dieser Container basiert auf dem OSS JetStream-Projekt. Weitere Informationen zu MaxText-JetStream-Flags finden Sie unter JetStream-MaxText-Server-Flags. Im folgenden Beispiel werden Trillium-Chips (v6e-16) verwendet, um den Llama3.1-405b-Int8-Prüfpunkt zu laden und darauf Inferenz auszuführen. In diesem Beispiel wird davon ausgegangen, dass Sie bereits einen GKE-Cluster mit mindestens einem v6e-16-Knotenpool haben.

Modellserver und Pathways starten

  1. Rufen Sie die Anmeldedaten für den Cluster ab und fügen Sie sie Ihrem lokalen kubectl-Kontext hinzu.
          gcloud container clusters get-credentials $CLUSTER \
          --zone=$ZONE \
          --project=$PROJECT \
          && kubectl config set-context --current --namespace=default
        
  2. Stellen Sie die LeaderWorkerSet (LWS) API bereit.
          VERSION=v0.4.0
          kubectl apply --server-side -f "https://github.com/kubernetes-sigs/lws/releases/download/${VERSION}/manifests.yaml"
        
  3. Kopieren Sie den folgenden YAML-Code und fügen Sie ihn in eine Datei namens pathways-job.yaml ein: Dieser YAML-Code wurde für die v6e-16-Scheibenform optimiert. Weitere Informationen zum Konvertieren eines Meta-Prüfpunkts in einen JAX-kompatiblen Prüfpunkt finden Sie in der Anleitung zum Erstellen von Prüfpunkten unter Inferenzprüfpunkte erstellen. Ein Beispiel für die Anleitung für Llama3.1-405B finden Sie hier.
        apiVersion: leaderworkerset.x-k8s.io/v1
        kind: LeaderWorkerSet
        metadata:
          name: jetstream-pathways
          annotations:
            leaderworkerset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
        spec:
          replicas: 1
          leaderWorkerTemplate:
            leaderTemplate:
              metadata:
                labels: 
                  app: jetstream-pathways
              spec:
                nodeSelector:
                  cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice
                  cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 4x4
                tolerations:
                - key: "google.com/tpu"
                  operator: "Exists"
                  effect: "NoSchedule"
                containers:
                - name: pathways-proxy
                  image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest
                  args:
                  - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677
                  - --server_port=38681
                  - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
                  imagePullPolicy: Always
                  ports:
                  - containerPort: 38681
                - name: pathways-rm
                  env:       
                  - name: HOST_ADDRESS
                    value: "$(LWS_LEADER_ADDRESS)"
                  - name: TPU_SKIP_MDS_QUERY
                    value: "true"
                  image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest
                  args:
                  - --server_port=38677
                  - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
                  - --node_type=resource_manager
                  - --instance_count=1
                  - --instance_type=tpuv6e:TPU_TOPOLOGY # Example: 4x4
                  - --temporary_flags_for_debugging=temporary_flag_for_debugging_worker_expected_tpu_chip_config=megachip_tccontrol
                  imagePullPolicy: Always
                  ports:
                  - containerPort: 38677
                - name: jax-tpu
                  image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-pathways:v0.2.0
                  # Optimized settings used to serve Llama3.1-405b.
                  args:
                  - MaxText/configs/v5e/inference/llama3_405b_v5e-64.yml
                  - model_name=llama3.1-405b
                  - load_parameters_path=GCS_CHECKPOINT_PATH
                  - max_prefill_predict_length=1024
                  - max_target_length=2048
                  - async_checkpointing=false
                  - steps=1
                  - ici_fsdp_parallelism=1
                  - ici_autoregressive_parallelism=2
                  - ici_tensor_parallelism=8
                  - scan_layers=false
                  - weight_dtype=bfloat16
                  - per_device_batch_size=10
                  - enable_single_controller=true
                  - quantization=int8
                  - quantize_kvcache=true
                  - checkpoint_is_quantized=true
                  - enable_model_warmup=true
                  imagePullPolicy: Always
                  ports: 
                  - containerPort: 9000
                  startupProbe:
                    httpGet:
                      path: /healthcheck
                      port: 8000
                      scheme: HTTP
                    periodSeconds: 1
                    initialDelaySeconds: 900
                    failureThreshold: 10000
                  livenessProbe:
                    httpGet:
                      path: /healthcheck
                      port: 8000
                      scheme: HTTP
                    periodSeconds: 60
                    failureThreshold: 10
                  readinessProbe:
                    httpGet:
                      path: /healthcheck
                      port: 8000
                      scheme: HTTP
                    periodSeconds: 60
                    failureThreshold: 10
                - name: jetstream-http
                  image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.3
                  imagePullPolicy: Always
                  ports:
                  - containerPort: 8000
            # The size variable defines the number of worker nodes to be created.
            # It must be equal to the number of hosts + 1 (for the leader node).
            size: 5
            workerTemplate:
              spec:
                nodeSelector:
                  cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice
                  cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 4x4
                tolerations:
                - key: "google.com/tpu"
                  operator: "Exists"
                  effect: "NoSchedule"
                containers:
                - name: worker
                  args:
                  - --server_port=38679
                  - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677
                  - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
                  image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest
                  imagePullPolicy: Always
                  ports:
                  - containerPort: 38679
                  resources:
                    limits:
                      google.com/tpu: "4"
        --- 
        apiVersion: v1
        kind: Service
        metadata:
          name: jetstream-svc
        spec:
          selector:
            app: jetstream-pathways
          ports:
          - protocol: TCP
            name: jetstream-http
            port: 8000
            targetPort: 8000
        
    Ersetzen Sie Folgendes:
    • TPU_ACCELERATOR_TYPE: Der TPU-Beschleunigertyp. Beispiel: tpu-v6e-slice.
    • TPU_TOPOLOGY: Die TPU-Topologie. Beispiel: 2x4.
    • GCS_CHECKPOINT_PATH: Der GCS-Pfad zum Prüfpunkt.
    Wenden Sie diese YAML-Datei an. Warten Sie, bis der PathwaysJob geplant ist. Nach der Planung kann es einige Zeit dauern, bis der Modellserver den Prüfpunkt wiederherstellt. Beim Modell 405B dauert das etwa 7 Minuten.
  4. Sehen Sie sich die Kubernetes-Logs an, um festzustellen, ob der JetStream-Modellserver bereit ist: Der Arbeitslastname in der vorherigen YAML-Datei war „jetstream-pathways“ und „0“ ist der Head-Knoten.
          kubectl logs -f jetstream-pathways-0 -c jax-tpu
          
    Die Ausgabe sieht etwa so aus. Das bedeutet, dass der JetStream-Modellserver bereit ist, Anfragen zu bearbeiten:
        2025-03-02 02:15:07,682 - JetstreamLogger - INFO - Initializing the driver with 1 prefill engines and 1 generate engines in interleaved mode
        2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up prefill thread 0.
        2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up transfer thread 0.
        2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up generate thread 0.
        2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up detokenize thread 0.
        2025-03-02 02:15:07,685 - JetstreamLogger - INFO - Driver initialized.
        ...
        ...
        ...
        INFO:     Started server process [7]
        INFO:     Waiting for application startup.
        INFO:     Application startup complete.
        INFO:     Uvicorn running on http://0.0.0.0:9999 (Press CTRL+C to quit)
        

Verbindung zum Modellserver herstellen

Sie können über den ClusterIP-Dienst von GKE auf das JetStream Pathways-Deployment zugreifen. Der ClusterIP-Dienst ist nur innerhalb des Clusters erreichbar. Wenn Sie also von außerhalb des Clusters auf den Dienst zugreifen möchten, müssen Sie zuerst eine Portweiterleitungssitzung einrichten. Führen Sie dazu den folgenden Befehl aus:

kubectl port-forward pod/${HEAD_POD} 8000:8000

Prüfen Sie, ob Sie auf den JetStream-HTTP-Server zugreifen können. Öffnen Sie dazu ein neues Terminal und führen Sie den folgenden Befehl aus:

curl --request POST \
--header "Content-type: application/json" \
-s \
localhost:8000/generate \
--data \
'{
    "prompt": "What are the top 5 programming languages",
    "max_tokens": 200
}'

Die erste Anfrage kann aufgrund der Aufwärmphase des Modells einige Sekunden dauern. Die Ausgabe sollte in etwa so aussehen:

{
    "response": " for web development?\nThe top 5 programming languages for web development are:\n1. **JavaScript**: JavaScript is the most popular language for web development, used by over 90% of websites for client-side scripting. It's also popular for server-side programming with technologies like Node.js.\n2. **HTML/CSS**: HTML (Hypertext Markup Language) and CSS (Cascading Style Sheets) are not programming languages, but are essential for building websites. HTML is used for structuring content, while CSS is used for styling and layout.\n3. **Python**: Python is a popular language for web development, especially with frameworks like Django and Flask. It's known for its simplicity, flexibility, and large community of developers.\n4. **Java**: Java is a popular language for building enterprise-level web applications, especially with frameworks like Spring and Hibernate. It's known for its platform independence, strong security features, and large community of developers.\n5. **PHP**: PHP is a mature language for web"
}

Aufgeschlüsselte Inferenz

Disaggregated Serving ist eine Technik zum Ausführen von Large Language Models (LLMs), bei der die Phasen Prefill und Decode in separate Prozesse aufgeteilt werden, die möglicherweise auf verschiedenen Computern ausgeführt werden. So lassen sich Ressourcen besser nutzen und die Leistung und Effizienz verbessern, insbesondere bei großen Modellen.

  • Vorfüllen: In dieser Phase wird der Eingabe-Prompt verarbeitet und eine Zwischenrepräsentation (z. B. ein Schlüssel/Wert-Cache) generiert. Sie sind oft rechenintensiv.
  • Decodieren: In dieser Phase werden die Ausgabetokens einzeln mithilfe der Prefill-Darstellung generiert. Sie ist in der Regel durch die Speicherbandbreite begrenzt.

Durch die Trennung dieser Phasen können Vorfüllen und Decodieren parallel ausgeführt werden, was den Durchsatz und die Latenz verbessert.

Wenn Sie die disaggregierte Bereitstellung aktivieren möchten, ändern Sie die folgende YAML-Datei so, dass zwei v6e-8-Slices verwendet werden: einer für das Vorabfüllen und der andere für das Generieren. Bevor Sie fortfahren, muss Ihr GKE-Cluster mindestens zwei Knotenpools haben, die mit dieser v6e-8-Topologie konfiguriert sind. Für eine optimale Leistung wurden bestimmte XLA-Flags konfiguriert.

Erstellen Sie einen llama2-70b-Checkpoint nach dem gleichen Verfahren wie beim Erstellen eines llama3.1-405b-Checkpoints, das im vorherigen Abschnitt beschrieben wird.

  1. Wenn Sie den JetStream-Server im disaggregierten Modus mit Pathways starten möchten, kopieren Sie den folgenden YAML-Code und fügen Sie ihn in eine Datei mit dem Namen pathways-job.yaml ein:
    apiVersion: leaderworkerset.x-k8s.io/v1
    kind: LeaderWorkerSet
    metadata:
      name: jetstream-pathways
      annotations:
        leaderworkerset.sigs.k8s.io/subgroup-exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      replicas: 1
      leaderWorkerTemplate:
        subGroupPolicy:
          subGroupSize: 2
        leaderTemplate:
          metadata:
            labels:
              app: jetstream-pathways
          spec:
            nodeSelector:
              cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice
              cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 2x4
            tolerations:
            - key: "google.com/tpu"
              operator: "Exists"
              effect: "NoSchedule"
            containers:
            - name: pathways-proxy
              image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest
              args:
              - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677
              - --server_port=38681
              - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
              imagePullPolicy: Always
              ports:
              - containerPort: 38681
            - name: pathways-rm
              env:       
              - name: HOST_ADDRESS
                value: "$(LWS_LEADER_ADDRESS)"
              - name: TPU_SKIP_MDS_QUERY
                value: "true"
              image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest
              args:
              - --server_port=38677
              - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
              - --node_type=resource_manager
              - --instance_count=2
              - --instance_type=tpuv6e:TPU_TOPOLOGY # Example: 2x4
              imagePullPolicy: Always
              ports:
              - containerPort: 38677
            - name: jax-tpu
              image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-pathways:v0.2.0
              # Optimized settings used to serve Llama2-70b.
              args:
              - MaxText/configs/base.yml
              - tokenizer_path=assets/tokenizer.llama2
              - load_parameters_path=GCS_CHECKPOINT_PATH
              - max_prefill_predict_length=1024
              - max_target_length=2048
              - model_name=llama2-70b
              - ici_fsdp_parallelism=1
              - ici_autoregressive_parallelism=1
              - ici_tensor_parallelism=-1
              - scan_layers=false
              - weight_dtype=bfloat16
              - per_device_batch_size=1
              - checkpoint_is_quantized=true 
              - quantization=int8
              - quantize_kvcache=true
              - compute_axis_order=0,2,1,3
              - ar_cache_axis_order=0,2,1,3
              - stack_prefill_result_cache=True
              # Specify disaggregated mode to run Jetstream
              - inference_server=ExperimentalMaxtextDisaggregatedServer_8
              - inference_benchmark_test=True
              - enable_model_warmup=True
              env:
              - name: LOG_LEVEL
                value: "INFO"
              imagePullPolicy: Always
              securityContext:
                capabilities:
                  add: ["SYS_PTRACE", "NET_ADMIN", "SYS_TIME"]
              ports: 
              - containerPort: 9000
              startupProbe:
                httpGet:
                  path: /healthcheck
                  port: 8000
                  scheme: HTTP
                periodSeconds: 1
                initialDelaySeconds: 240
                failureThreshold: 10000
              livenessProbe:
                httpGet:
                  path: /healthcheck
                  port: 8000
                  scheme: HTTP
                periodSeconds: 60
                failureThreshold: 100
              readinessProbe:
                httpGet:
                  path: /healthcheck
                  port: 8000
                  scheme: HTTP
                periodSeconds: 60
                failureThreshold: 100
            - name: jetstream-http
              image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.3
              imagePullPolicy: Always
              ports:
              - containerPort: 8000
        # The size variable defines the number of worker nodes to be created.
        # It must be equal to the number of hosts + 1 (for the leader node).
        size: 5
        workerTemplate:
          spec:
            nodeSelector:
              cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice
              cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 2x4
            containers:
            - name: worker
              args:
              - --server_port=38679
              - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677
              - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
              image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest
              imagePullPolicy: Always
              ports:
              - containerPort: 38679
              resources:
                limits:
                  google.com/tpu: "4"
    --- 
    apiVersion: v1
    kind: Service
    metadata:
      name: jetstream-svc
    spec:
      selector:
        app: jetstream-pathways
      ports:
      - protocol: TCP
        name: jetstream-http
        port: 8000
        targetPort: 8000
      
    Ersetzen Sie Folgendes:
    • TPU_ACCELERATOR_TYPE: Der TPU-Beschleunigertyp. Beispiel: tpu-v6e-slice.
    • TPU_TOPOLOGY: Die TPU-Topologie. Beispiel: 2x4.
    • GCS_CHECKPOINT_PATH: Der GCS-Pfad zum Prüfpunkt.
  2. Wenden Sie diese YAML-Datei an. Es dauert einige Zeit, bis der Modellserver den Checkpoint wiederherstellt. Beim 70B-Modell kann dies etwa 2 Minuten dauern.
      kubectl apply -f pathways-job.yaml
          
  3. Sehen Sie sich die Kubernetes-Logs an, um festzustellen, ob der JetStream-Modellserver bereit ist:
        kubectl logs -f jetstream-pathways-0 -c jax-tpu
        
    Es wird eine Ausgabe wie die folgende angezeigt, die angibt, dass der JetStream-Modellserver bereit ist, Anfragen zu bearbeiten:
        2025-03-02 02:15:07,682 - JetstreamLogger - INFO - Initializing the driver with 1 prefill engines and 1 generate engines in interleaved mode
        2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up prefill thread 0.
        2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up transfer thread 0.
        2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up generate thread 0.
        2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up detokenize thread 0.
        2025-03-02 02:15:07,685 - JetstreamLogger - INFO - Driver initialized.
        ...
        ...
        ...
        INFO:     Started server process [7]
        INFO:     Waiting for application startup.
        INFO:     Application startup complete.
        INFO:     Uvicorn running on http://0.0.0.0:9999 (Press CTRL+C to quit)
      

Verbindung zum Modellserver herstellen

Sie können über den ClusterIP-Dienst von GKE auf das JetStream Pathways-Deployment zugreifen. Der ClusterIP-Dienst ist nur innerhalb des Clusters erreichbar. Führen Sie daher den folgenden Befehl aus, um eine Portweiterleitungssitzung einzurichten und von außerhalb des Clusters auf den Dienst zuzugreifen:

kubectl port-forward pod/${HEAD_POD} 8000:8000

Prüfen Sie, ob Sie auf den JetStream-HTTP-Server zugreifen können. Öffnen Sie dazu ein neues Terminal und führen Sie den folgenden Befehl aus:

curl --request POST \
--header "Content-type: application/json" \
-s \
localhost:8000/generate \
--data \
'{
    "prompt": "What are the top 5 programming languages",
    "max_tokens": 200
}'

Die erste Anfrage kann aufgrund der Aufwärmphase des Modells einige Sekunden dauern. Die Ausgabe sollte in etwa so aussehen:

{
    "response": " used in software development?\nThe top 5 programming languages used in software development are:\n\n1. Java: Java is a popular programming language used for developing enterprise-level applications, Android apps, and web applications. Its platform independence and ability to run on any device that has a Java Virtual Machine (JVM) installed make it a favorite among developers.\n2. Python: Python is a versatile language that is widely used in software development, data analysis, artificial intelligence, and machine learning. Its simplicity, readability, and ease of use make it a popular choice among developers.\n3. JavaScript: JavaScript is a widely used programming language for web development, allowing developers to create interactive client-side functionality for web applications. It is also used for server-side programming, desktop and mobile application development, and game development.\n4. C++: C++ is a high-performance programming language used for developing operating systems, games, and other high-performance applications."
}

Nächste Schritte