Die Inferenz mit mehreren Hosts ist eine Methode zum Ausführen der Modellinferenz, bei der das Modell auf mehrere Accelerator-Hosts verteilt wird. So können große Modelle, die nicht auf einen einzelnen Host passen, abgeleitet werden. Pathways können sowohl für Batch- als auch für Echtzeit-Multihost-Inferenzanwendungsfälle bereitgestellt werden.
Hinweise
Sie benötigen Folgendes:
- GKE-Cluster mit Trillium-Chips (v6e-16) erstellt.
- Installierte Kubernetes-Tools
- TPU API aktiviert
- Google Kubernetes Engine API aktiviert
Batchinferenz mit JetStream ausführen
JetStream ist eine auf Durchsatz und Arbeitsspeicher optimierte Engine für die Inferenz großer Sprachmodelle (LLM) auf XLA-Geräten, hauptsächlich Tensor Processing Units (TPUs), die in JAX geschrieben sind.
Sie können ein vorgefertigtes JetStream-Docker-Image verwenden, um eine Batchinferenz-Arbeitslast auszuführen, wie im folgenden YAML-Code gezeigt. Dieser Container basiert auf dem OSS JetStream-Projekt.
Weitere Informationen zu MaxText-JetStream-Flags finden Sie unter JetStream-MaxText-Server-Flags.
Im folgenden Beispiel werden Trillium-Chips (v6e-16) verwendet, um den Llama3.1-405b-Int8-Prüfpunkt zu laden und darauf Inferenz auszuführen. In diesem Beispiel wird davon ausgegangen, dass Sie bereits einen GKE-Cluster mit mindestens einem v6e-16-Knotenpool haben.
Modellserver und Pathways starten
- Rufen Sie die Anmeldedaten für den Cluster ab und fügen Sie sie Ihrem lokalen kubectl-Kontext hinzu.
gcloud container clusters get-credentials $CLUSTER \ --zone=$ZONE \ --project=$PROJECT \ && kubectl config set-context --current --namespace=default
- Stellen Sie die LeaderWorkerSet (LWS) API bereit.
VERSION=v0.4.0 kubectl apply --server-side -f "https://github.com/kubernetes-sigs/lws/releases/download/${VERSION}/manifests.yaml"
- Kopieren Sie den folgenden YAML-Code und fügen Sie ihn in eine Datei namens
pathways-job.yamlein: Dieser YAML-Code wurde für diev6e-16-Scheibenform optimiert. Weitere Informationen zum Konvertieren eines Meta-Prüfpunkts in einen JAX-kompatiblen Prüfpunkt finden Sie in der Anleitung zum Erstellen von Prüfpunkten unter Inferenzprüfpunkte erstellen. Ein Beispiel für die Anleitung für Llama3.1-405B finden Sie hier. Ersetzen Sie Folgendes:apiVersion: leaderworkerset.x-k8s.io/v1 kind: LeaderWorkerSet metadata: name: jetstream-pathways annotations: leaderworkerset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool spec: replicas: 1 leaderWorkerTemplate: leaderTemplate: metadata: labels: app: jetstream-pathways spec: nodeSelector: cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 4x4 tolerations: - key: "google.com/tpu" operator: "Exists" effect: "NoSchedule" containers: - name: pathways-proxy image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest args: - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677 - --server_port=38681 - --gcs_scratch_location=gs://cloud-pathways-staging/tmp imagePullPolicy: Always ports: - containerPort: 38681 - name: pathways-rm env: - name: HOST_ADDRESS value: "$(LWS_LEADER_ADDRESS)" - name: TPU_SKIP_MDS_QUERY value: "true" image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest args: - --server_port=38677 - --gcs_scratch_location=gs://cloud-pathways-staging/tmp - --node_type=resource_manager - --instance_count=1 - --instance_type=tpuv6e:TPU_TOPOLOGY # Example: 4x4 - --temporary_flags_for_debugging=temporary_flag_for_debugging_worker_expected_tpu_chip_config=megachip_tccontrol imagePullPolicy: Always ports: - containerPort: 38677 - name: jax-tpu image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-pathways:v0.2.0 # Optimized settings used to serve Llama3.1-405b. args: - MaxText/configs/v5e/inference/llama3_405b_v5e-64.yml - model_name=llama3.1-405b - load_parameters_path=GCS_CHECKPOINT_PATH - max_prefill_predict_length=1024 - max_target_length=2048 - async_checkpointing=false - steps=1 - ici_fsdp_parallelism=1 - ici_autoregressive_parallelism=2 - ici_tensor_parallelism=8 - scan_layers=false - weight_dtype=bfloat16 - per_device_batch_size=10 - enable_single_controller=true - quantization=int8 - quantize_kvcache=true - checkpoint_is_quantized=true - enable_model_warmup=true imagePullPolicy: Always ports: - containerPort: 9000 startupProbe: httpGet: path: /healthcheck port: 8000 scheme: HTTP periodSeconds: 1 initialDelaySeconds: 900 failureThreshold: 10000 livenessProbe: httpGet: path: /healthcheck port: 8000 scheme: HTTP periodSeconds: 60 failureThreshold: 10 readinessProbe: httpGet: path: /healthcheck port: 8000 scheme: HTTP periodSeconds: 60 failureThreshold: 10 - name: jetstream-http image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.3 imagePullPolicy: Always ports: - containerPort: 8000 # The size variable defines the number of worker nodes to be created. # It must be equal to the number of hosts + 1 (for the leader node). size: 5 workerTemplate: spec: nodeSelector: cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 4x4 tolerations: - key: "google.com/tpu" operator: "Exists" effect: "NoSchedule" containers: - name: worker args: - --server_port=38679 - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677 - --gcs_scratch_location=gs://cloud-pathways-staging/tmp image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest imagePullPolicy: Always ports: - containerPort: 38679 resources: limits: google.com/tpu: "4" --- apiVersion: v1 kind: Service metadata: name: jetstream-svc spec: selector: app: jetstream-pathways ports: - protocol: TCP name: jetstream-http port: 8000 targetPort: 8000
TPU_ACCELERATOR_TYPE: Der TPU-Beschleunigertyp. Beispiel:tpu-v6e-slice.TPU_TOPOLOGY: Die TPU-Topologie. Beispiel:2x4.GCS_CHECKPOINT_PATH: Der GCS-Pfad zum Prüfpunkt.
- Sehen Sie sich die Kubernetes-Logs an, um festzustellen, ob der JetStream-Modellserver bereit ist:
Der Arbeitslastname in der vorherigen YAML-Datei war „jetstream-pathways“ und „0“ ist der Head-Knoten.
Die Ausgabe sieht etwa so aus. Das bedeutet, dass der JetStream-Modellserver bereit ist, Anfragen zu bearbeiten:kubectl logs -f jetstream-pathways-0 -c jax-tpu
2025-03-02 02:15:07,682 - JetstreamLogger - INFO - Initializing the driver with 1 prefill engines and 1 generate engines in interleaved mode 2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up prefill thread 0. 2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up transfer thread 0. 2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up generate thread 0. 2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up detokenize thread 0. 2025-03-02 02:15:07,685 - JetstreamLogger - INFO - Driver initialized. ... ... ... INFO: Started server process [7] INFO: Waiting for application startup. INFO: Application startup complete. INFO: Uvicorn running on http://0.0.0.0:9999 (Press CTRL+C to quit)
Verbindung zum Modellserver herstellen
Sie können über den ClusterIP-Dienst von GKE auf das JetStream Pathways-Deployment zugreifen. Der ClusterIP-Dienst ist nur innerhalb des Clusters erreichbar. Wenn Sie also von außerhalb des Clusters auf den Dienst zugreifen möchten, müssen Sie zuerst eine Portweiterleitungssitzung einrichten. Führen Sie dazu den folgenden Befehl aus:
kubectl port-forward pod/${HEAD_POD} 8000:8000
Prüfen Sie, ob Sie auf den JetStream-HTTP-Server zugreifen können. Öffnen Sie dazu ein neues Terminal und führen Sie den folgenden Befehl aus:
curl --request POST \
--header "Content-type: application/json" \
-s \
localhost:8000/generate \
--data \
'{
"prompt": "What are the top 5 programming languages",
"max_tokens": 200
}'
Die erste Anfrage kann aufgrund der Aufwärmphase des Modells einige Sekunden dauern. Die Ausgabe sollte in etwa so aussehen:
{
"response": " for web development?\nThe top 5 programming languages for web development are:\n1. **JavaScript**: JavaScript is the most popular language for web development, used by over 90% of websites for client-side scripting. It's also popular for server-side programming with technologies like Node.js.\n2. **HTML/CSS**: HTML (Hypertext Markup Language) and CSS (Cascading Style Sheets) are not programming languages, but are essential for building websites. HTML is used for structuring content, while CSS is used for styling and layout.\n3. **Python**: Python is a popular language for web development, especially with frameworks like Django and Flask. It's known for its simplicity, flexibility, and large community of developers.\n4. **Java**: Java is a popular language for building enterprise-level web applications, especially with frameworks like Spring and Hibernate. It's known for its platform independence, strong security features, and large community of developers.\n5. **PHP**: PHP is a mature language for web"
}
Aufgeschlüsselte Inferenz
Disaggregated Serving ist eine Technik zum Ausführen von Large Language Models (LLMs), bei der die Phasen Prefill und Decode in separate Prozesse aufgeteilt werden, die möglicherweise auf verschiedenen Computern ausgeführt werden. So lassen sich Ressourcen besser nutzen und die Leistung und Effizienz verbessern, insbesondere bei großen Modellen.
- Vorfüllen: In dieser Phase wird der Eingabe-Prompt verarbeitet und eine Zwischenrepräsentation (z. B. ein Schlüssel/Wert-Cache) generiert. Sie sind oft rechenintensiv.
- Decodieren: In dieser Phase werden die Ausgabetokens einzeln mithilfe der Prefill-Darstellung generiert. Sie ist in der Regel durch die Speicherbandbreite begrenzt.
Durch die Trennung dieser Phasen können Vorfüllen und Decodieren parallel ausgeführt werden, was den Durchsatz und die Latenz verbessert.
Wenn Sie die disaggregierte Bereitstellung aktivieren möchten, ändern Sie die folgende YAML-Datei so, dass zwei v6e-8-Slices verwendet werden: einer für das Vorabfüllen und der andere für das Generieren. Bevor Sie fortfahren, muss Ihr GKE-Cluster mindestens zwei Knotenpools haben, die mit dieser v6e-8-Topologie konfiguriert sind.
Für eine optimale Leistung wurden bestimmte XLA-Flags konfiguriert.
Erstellen Sie einen llama2-70b-Checkpoint nach dem gleichen Verfahren wie beim Erstellen eines llama3.1-405b-Checkpoints, das im vorherigen Abschnitt beschrieben wird.
- Wenn Sie den JetStream-Server im disaggregierten Modus mit Pathways starten möchten, kopieren Sie den folgenden YAML-Code und fügen Sie ihn in eine Datei mit dem Namen
pathways-job.yamlein: Ersetzen Sie Folgendes:apiVersion: leaderworkerset.x-k8s.io/v1 kind: LeaderWorkerSet metadata: name: jetstream-pathways annotations: leaderworkerset.sigs.k8s.io/subgroup-exclusive-topology: cloud.google.com/gke-nodepool spec: replicas: 1 leaderWorkerTemplate: subGroupPolicy: subGroupSize: 2 leaderTemplate: metadata: labels: app: jetstream-pathways spec: nodeSelector: cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 2x4 tolerations: - key: "google.com/tpu" operator: "Exists" effect: "NoSchedule" containers: - name: pathways-proxy image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest args: - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677 - --server_port=38681 - --gcs_scratch_location=gs://cloud-pathways-staging/tmp imagePullPolicy: Always ports: - containerPort: 38681 - name: pathways-rm env: - name: HOST_ADDRESS value: "$(LWS_LEADER_ADDRESS)" - name: TPU_SKIP_MDS_QUERY value: "true" image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest args: - --server_port=38677 - --gcs_scratch_location=gs://cloud-pathways-staging/tmp - --node_type=resource_manager - --instance_count=2 - --instance_type=tpuv6e:TPU_TOPOLOGY # Example: 2x4 imagePullPolicy: Always ports: - containerPort: 38677 - name: jax-tpu image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-pathways:v0.2.0 # Optimized settings used to serve Llama2-70b. args: - MaxText/configs/base.yml - tokenizer_path=assets/tokenizer.llama2 - load_parameters_path=GCS_CHECKPOINT_PATH - max_prefill_predict_length=1024 - max_target_length=2048 - model_name=llama2-70b - ici_fsdp_parallelism=1 - ici_autoregressive_parallelism=1 - ici_tensor_parallelism=-1 - scan_layers=false - weight_dtype=bfloat16 - per_device_batch_size=1 - checkpoint_is_quantized=true - quantization=int8 - quantize_kvcache=true - compute_axis_order=0,2,1,3 - ar_cache_axis_order=0,2,1,3 - stack_prefill_result_cache=True # Specify disaggregated mode to run Jetstream - inference_server=ExperimentalMaxtextDisaggregatedServer_8 - inference_benchmark_test=True - enable_model_warmup=True env: - name: LOG_LEVEL value: "INFO" imagePullPolicy: Always securityContext: capabilities: add: ["SYS_PTRACE", "NET_ADMIN", "SYS_TIME"] ports: - containerPort: 9000 startupProbe: httpGet: path: /healthcheck port: 8000 scheme: HTTP periodSeconds: 1 initialDelaySeconds: 240 failureThreshold: 10000 livenessProbe: httpGet: path: /healthcheck port: 8000 scheme: HTTP periodSeconds: 60 failureThreshold: 100 readinessProbe: httpGet: path: /healthcheck port: 8000 scheme: HTTP periodSeconds: 60 failureThreshold: 100 - name: jetstream-http image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.3 imagePullPolicy: Always ports: - containerPort: 8000 # The size variable defines the number of worker nodes to be created. # It must be equal to the number of hosts + 1 (for the leader node). size: 5 workerTemplate: spec: nodeSelector: cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 2x4 containers: - name: worker args: - --server_port=38679 - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677 - --gcs_scratch_location=gs://cloud-pathways-staging/tmp image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest imagePullPolicy: Always ports: - containerPort: 38679 resources: limits: google.com/tpu: "4" --- apiVersion: v1 kind: Service metadata: name: jetstream-svc spec: selector: app: jetstream-pathways ports: - protocol: TCP name: jetstream-http port: 8000 targetPort: 8000
TPU_ACCELERATOR_TYPE: Der TPU-Beschleunigertyp. Beispiel:tpu-v6e-slice.TPU_TOPOLOGY: Die TPU-Topologie. Beispiel:2x4.GCS_CHECKPOINT_PATH: Der GCS-Pfad zum Prüfpunkt.
- Wenden Sie diese YAML-Datei an. Es dauert einige Zeit, bis der Modellserver den Checkpoint wiederherstellt. Beim 70B-Modell kann dies etwa 2 Minuten dauern.
kubectl apply -f pathways-job.yaml
- Sehen Sie sich die Kubernetes-Logs an, um festzustellen, ob der JetStream-Modellserver bereit ist:
Es wird eine Ausgabe wie die folgende angezeigt, die angibt, dass der JetStream-Modellserver bereit ist, Anfragen zu bearbeiten:kubectl logs -f jetstream-pathways-0 -c jax-tpu
2025-03-02 02:15:07,682 - JetstreamLogger - INFO - Initializing the driver with 1 prefill engines and 1 generate engines in interleaved mode 2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up prefill thread 0. 2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up transfer thread 0. 2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up generate thread 0. 2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up detokenize thread 0. 2025-03-02 02:15:07,685 - JetstreamLogger - INFO - Driver initialized. ... ... ... INFO: Started server process [7] INFO: Waiting for application startup. INFO: Application startup complete. INFO: Uvicorn running on http://0.0.0.0:9999 (Press CTRL+C to quit)
Verbindung zum Modellserver herstellen
Sie können über den ClusterIP-Dienst von GKE auf das JetStream Pathways-Deployment zugreifen. Der ClusterIP-Dienst ist nur innerhalb des Clusters erreichbar. Führen Sie daher den folgenden Befehl aus, um eine Portweiterleitungssitzung einzurichten und von außerhalb des Clusters auf den Dienst zuzugreifen:
kubectl port-forward pod/${HEAD_POD} 8000:8000
Prüfen Sie, ob Sie auf den JetStream-HTTP-Server zugreifen können. Öffnen Sie dazu ein neues Terminal und führen Sie den folgenden Befehl aus:
curl --request POST \
--header "Content-type: application/json" \
-s \
localhost:8000/generate \
--data \
'{
"prompt": "What are the top 5 programming languages",
"max_tokens": 200
}'
Die erste Anfrage kann aufgrund der Aufwärmphase des Modells einige Sekunden dauern. Die Ausgabe sollte in etwa so aussehen:
{
"response": " used in software development?\nThe top 5 programming languages used in software development are:\n\n1. Java: Java is a popular programming language used for developing enterprise-level applications, Android apps, and web applications. Its platform independence and ability to run on any device that has a Java Virtual Machine (JVM) installed make it a favorite among developers.\n2. Python: Python is a versatile language that is widely used in software development, data analysis, artificial intelligence, and machine learning. Its simplicity, readability, and ease of use make it a popular choice among developers.\n3. JavaScript: JavaScript is a widely used programming language for web development, allowing developers to create interactive client-side functionality for web applications. It is also used for server-side programming, desktop and mobile application development, and game development.\n4. C++: C++ is a high-performance programming language used for developing operating systems, games, and other high-performance applications."
}
Nächste Schritte
- Batch-Arbeitslasten mit Pathways
- Interaktiver Modus für Lernpfade
- JAX-Arbeitslasten zu Pathways migrieren
- Belastbares Training mit Pathways
- Pfade zur Fehlerbehebung