L'inferenza multihost è un metodo di esecuzione dell'inferenza del modello che distribuisce il modello su più host di acceleratori. Ciò consente l'inferenza di modelli di grandi dimensioni che non possono essere ospitati su un singolo host. I percorsi possono essere implementati per casi d'uso di inferenza multihost in modalità batch e in tempo reale.
Prima di iniziare
Assicurati di avere:
- È stato creato un cluster GKE che utilizza chip Trillium (v6e-16).
- Strumenti Kubernetes installati
- Abilitato l'API TPU
- Abilitato l'API Google Kubernetes Engine
Esegui l'inferenza batch utilizzando JetStream
JetStream è un motore ottimizzato per la velocità effettiva e la memoria per l'inferenza dei modelli linguistici di grandi dimensioni (LLM) sui dispositivi XLA, principalmente Tensor Processing Unit (TPU) scritti in JAX.
Puoi utilizzare un'immagine Docker JetStream predefinita per eseguire un workload di inferenza batch,
come mostrato nel seguente YAML. Questo container è creato dal
progetto OSS JetStream.
Per saperne di più sui flag MaxText-JetStream, consulta
Flag del server JetStream MaxText.
L'esempio seguente utilizza i chip Trillium (v6e-16) per caricare il checkpoint Llama3.1-405b
int8 ed eseguire l'inferenza. Questo esempio presuppone che tu abbia già
un cluster GKE con almeno un node pool v6e-16 al suo interno.
Avvia il server del modello e Pathways
- Recupera le credenziali per il cluster e aggiungile al contesto kubectl locale.
gcloud container clusters get-credentials $CLUSTER \ --zone=$ZONE \ --project=$PROJECT \ && kubectl config set-context --current --namespace=default
- Esegui il deployment dell'API LeaderWorkerSet (LWS).
VERSION=v0.4.0 kubectl apply --server-side -f "https://github.com/kubernetes-sigs/lws/releases/download/${VERSION}/manifests.yaml"
- Copia e incolla il seguente codice YAML in un file denominato
pathways-job.yaml: Questo codice YAML è stato ottimizzato per la forma della sezionev6e-16. Per ulteriori informazioni su come convertire un checkpoint Meta in un checkpoint compatibile con JAX, segui la guida alla creazione dei checkpoint in Creazione di checkpoint di inferenza. Ad esempio, le istruzioni per Llama3.1-405B sono disponibili qui Conversione del checkpoint per Llama3.1-405B. Sostituisci quanto segue:apiVersion: leaderworkerset.x-k8s.io/v1 kind: LeaderWorkerSet metadata: name: jetstream-pathways annotations: leaderworkerset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool spec: replicas: 1 leaderWorkerTemplate: leaderTemplate: metadata: labels: app: jetstream-pathways spec: nodeSelector: cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 4x4 tolerations: - key: "google.com/tpu" operator: "Exists" effect: "NoSchedule" containers: - name: pathways-proxy image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest args: - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677 - --server_port=38681 - --gcs_scratch_location=gs://cloud-pathways-staging/tmp imagePullPolicy: Always ports: - containerPort: 38681 - name: pathways-rm env: - name: HOST_ADDRESS value: "$(LWS_LEADER_ADDRESS)" - name: TPU_SKIP_MDS_QUERY value: "true" image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest args: - --server_port=38677 - --gcs_scratch_location=gs://cloud-pathways-staging/tmp - --node_type=resource_manager - --instance_count=1 - --instance_type=tpuv6e:TPU_TOPOLOGY # Example: 4x4 - --temporary_flags_for_debugging=temporary_flag_for_debugging_worker_expected_tpu_chip_config=megachip_tccontrol imagePullPolicy: Always ports: - containerPort: 38677 - name: jax-tpu image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-pathways:v0.2.0 # Optimized settings used to serve Llama3.1-405b. args: - MaxText/configs/v5e/inference/llama3_405b_v5e-64.yml - model_name=llama3.1-405b - load_parameters_path=GCS_CHECKPOINT_PATH - max_prefill_predict_length=1024 - max_target_length=2048 - async_checkpointing=false - steps=1 - ici_fsdp_parallelism=1 - ici_autoregressive_parallelism=2 - ici_tensor_parallelism=8 - scan_layers=false - weight_dtype=bfloat16 - per_device_batch_size=10 - enable_single_controller=true - quantization=int8 - quantize_kvcache=true - checkpoint_is_quantized=true - enable_model_warmup=true imagePullPolicy: Always ports: - containerPort: 9000 startupProbe: httpGet: path: /healthcheck port: 8000 scheme: HTTP periodSeconds: 1 initialDelaySeconds: 900 failureThreshold: 10000 livenessProbe: httpGet: path: /healthcheck port: 8000 scheme: HTTP periodSeconds: 60 failureThreshold: 10 readinessProbe: httpGet: path: /healthcheck port: 8000 scheme: HTTP periodSeconds: 60 failureThreshold: 10 - name: jetstream-http image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.3 imagePullPolicy: Always ports: - containerPort: 8000 # The size variable defines the number of worker nodes to be created. # It must be equal to the number of hosts + 1 (for the leader node). size: 5 workerTemplate: spec: nodeSelector: cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 4x4 tolerations: - key: "google.com/tpu" operator: "Exists" effect: "NoSchedule" containers: - name: worker args: - --server_port=38679 - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677 - --gcs_scratch_location=gs://cloud-pathways-staging/tmp image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest imagePullPolicy: Always ports: - containerPort: 38679 resources: limits: google.com/tpu: "4" --- apiVersion: v1 kind: Service metadata: name: jetstream-svc spec: selector: app: jetstream-pathways ports: - protocol: TCP name: jetstream-http port: 8000 targetPort: 8000
TPU_ACCELERATOR_TYPE: il tipo di acceleratore TPU. Ad esempio,tpu-v6e-slice.TPU_TOPOLOGY: la topologia TPU. Ad esempio,2x4.GCS_CHECKPOINT_PATH: il percorso GCS del checkpoint.
- Esamina i log di Kubernetes per verificare se il server del modello JetStream è pronto:
Il workload è stato denominato `jetstream-pathways` nel file YAML precedente e `0`
è il nodo head.
L'output è simile al seguente, che indica che il server del modello JetStream è pronto a gestire le richieste:kubectl logs -f jetstream-pathways-0 -c jax-tpu
2025-03-02 02:15:07,682 - JetstreamLogger - INFO - Initializing the driver with 1 prefill engines and 1 generate engines in interleaved mode 2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up prefill thread 0. 2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up transfer thread 0. 2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up generate thread 0. 2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up detokenize thread 0. 2025-03-02 02:15:07,685 - JetstreamLogger - INFO - Driver initialized. ... ... ... INFO: Started server process [7] INFO: Waiting for application startup. INFO: Application startup complete. INFO: Uvicorn running on http://0.0.0.0:9999 (Press CTRL+C to quit)
Connettiti al server del modello
Puoi accedere al deployment di JetStream Pathways utilizzando il servizio ClusterIP di GKE. Il servizio ClusterIP è raggiungibile solo dall'interno del cluster. Pertanto, per accedere al servizio dall'esterno del cluster, devi prima stabilire una sessione di port forwarding eseguendo il seguente comando:
kubectl port-forward pod/${HEAD_POD} 8000:8000
Verifica di poter accedere al server HTTP JetStream aprendo un nuovo terminale ed eseguendo questo comando:
curl --request POST \
--header "Content-type: application/json" \
-s \
localhost:8000/generate \
--data \
'{
"prompt": "What are the top 5 programming languages",
"max_tokens": 200
}'
Il completamento della richiesta iniziale può richiedere diversi secondi a causa del riscaldamento del modello. L'output dovrebbe essere simile al seguente:
{
"response": " for web development?\nThe top 5 programming languages for web development are:\n1. **JavaScript**: JavaScript is the most popular language for web development, used by over 90% of websites for client-side scripting. It's also popular for server-side programming with technologies like Node.js.\n2. **HTML/CSS**: HTML (Hypertext Markup Language) and CSS (Cascading Style Sheets) are not programming languages, but are essential for building websites. HTML is used for structuring content, while CSS is used for styling and layout.\n3. **Python**: Python is a popular language for web development, especially with frameworks like Django and Flask. It's known for its simplicity, flexibility, and large community of developers.\n4. **Java**: Java is a popular language for building enterprise-level web applications, especially with frameworks like Spring and Hibernate. It's known for its platform independence, strong security features, and large community of developers.\n5. **PHP**: PHP is a mature language for web"
}
Inferenza disaggregata
Il servizio disaggregato è una tecnica per l'esecuzione di modelli linguistici di grandi dimensioni (LLM) che separa le fasi di precompilazione e decodifica in processi diversi, potenzialmente su macchine diverse. Ciò consente un migliore utilizzo delle risorse e può portare a miglioramenti in termini di prestazioni ed efficienza, soprattutto per i modelli di grandi dimensioni.
- Precompilazione: in questa fase viene elaborato il prompt di input e viene generata una rappresentazione intermedia (come una cache chiave-valore). Spesso richiede molta potenza di calcolo.
- Decodifica: questa fase genera i token di output, uno alla volta, utilizzando la rappresentazione di precompilazione. In genere è vincolato dalla larghezza di banda della memoria.
Separando queste fasi, la pubblicazione disaggregata consente di eseguire il precompilamento e la decodifica in parallelo, migliorando la velocità effettiva e la latenza.
Per attivare la pubblicazione disaggregata, modifica il seguente file YAML in modo da utilizzare due sezioni v6e-8: una per il prefill e l'altra per la generazione. Prima di procedere, assicurati che il cluster GKE abbia almeno due node pool configurati con questa topologia v6e-8.
Per un rendimento ottimale, sono stati configurati flag XLA specifici.
Crea un checkpoint llama2-70b seguendo la stessa procedura di creazione del checkpoint llama3.1-405b, descritta nella sezione precedente.
- Per avviare il server JetStream in modalità disaggregata utilizzando Pathways, copia
e incolla il seguente YAML in un file denominato
pathways-job.yaml: Sostituisci quanto segue:apiVersion: leaderworkerset.x-k8s.io/v1 kind: LeaderWorkerSet metadata: name: jetstream-pathways annotations: leaderworkerset.sigs.k8s.io/subgroup-exclusive-topology: cloud.google.com/gke-nodepool spec: replicas: 1 leaderWorkerTemplate: subGroupPolicy: subGroupSize: 2 leaderTemplate: metadata: labels: app: jetstream-pathways spec: nodeSelector: cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 2x4 tolerations: - key: "google.com/tpu" operator: "Exists" effect: "NoSchedule" containers: - name: pathways-proxy image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest args: - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677 - --server_port=38681 - --gcs_scratch_location=gs://cloud-pathways-staging/tmp imagePullPolicy: Always ports: - containerPort: 38681 - name: pathways-rm env: - name: HOST_ADDRESS value: "$(LWS_LEADER_ADDRESS)" - name: TPU_SKIP_MDS_QUERY value: "true" image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest args: - --server_port=38677 - --gcs_scratch_location=gs://cloud-pathways-staging/tmp - --node_type=resource_manager - --instance_count=2 - --instance_type=tpuv6e:TPU_TOPOLOGY # Example: 2x4 imagePullPolicy: Always ports: - containerPort: 38677 - name: jax-tpu image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-pathways:v0.2.0 # Optimized settings used to serve Llama2-70b. args: - MaxText/configs/base.yml - tokenizer_path=assets/tokenizer.llama2 - load_parameters_path=GCS_CHECKPOINT_PATH - max_prefill_predict_length=1024 - max_target_length=2048 - model_name=llama2-70b - ici_fsdp_parallelism=1 - ici_autoregressive_parallelism=1 - ici_tensor_parallelism=-1 - scan_layers=false - weight_dtype=bfloat16 - per_device_batch_size=1 - checkpoint_is_quantized=true - quantization=int8 - quantize_kvcache=true - compute_axis_order=0,2,1,3 - ar_cache_axis_order=0,2,1,3 - stack_prefill_result_cache=True # Specify disaggregated mode to run Jetstream - inference_server=ExperimentalMaxtextDisaggregatedServer_8 - inference_benchmark_test=True - enable_model_warmup=True env: - name: LOG_LEVEL value: "INFO" imagePullPolicy: Always securityContext: capabilities: add: ["SYS_PTRACE", "NET_ADMIN", "SYS_TIME"] ports: - containerPort: 9000 startupProbe: httpGet: path: /healthcheck port: 8000 scheme: HTTP periodSeconds: 1 initialDelaySeconds: 240 failureThreshold: 10000 livenessProbe: httpGet: path: /healthcheck port: 8000 scheme: HTTP periodSeconds: 60 failureThreshold: 100 readinessProbe: httpGet: path: /healthcheck port: 8000 scheme: HTTP periodSeconds: 60 failureThreshold: 100 - name: jetstream-http image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.3 imagePullPolicy: Always ports: - containerPort: 8000 # The size variable defines the number of worker nodes to be created. # It must be equal to the number of hosts + 1 (for the leader node). size: 5 workerTemplate: spec: nodeSelector: cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 2x4 containers: - name: worker args: - --server_port=38679 - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677 - --gcs_scratch_location=gs://cloud-pathways-staging/tmp image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest imagePullPolicy: Always ports: - containerPort: 38679 resources: limits: google.com/tpu: "4" --- apiVersion: v1 kind: Service metadata: name: jetstream-svc spec: selector: app: jetstream-pathways ports: - protocol: TCP name: jetstream-http port: 8000 targetPort: 8000
TPU_ACCELERATOR_TYPE: il tipo di acceleratore TPU. Ad esempio,tpu-v6e-slice.TPU_TOPOLOGY: la topologia TPU. Ad esempio,2x4.GCS_CHECKPOINT_PATH: il percorso GCS del checkpoint.
- Applica questo file YAML. Il server del modello impiegherà un po' di tempo per ripristinare il
checkpoint. Per il modello 70B, l'operazione potrebbe richiedere circa 2 minuti.
kubectl apply -f pathways-job.yaml
- Esamina i log di Kubernetes per verificare se il server del modello JetStream è pronto:
Vedrai un output simile al seguente, che indica che il server del modello JetStream è pronto a gestire le richieste:kubectl logs -f jetstream-pathways-0 -c jax-tpu
2025-03-02 02:15:07,682 - JetstreamLogger - INFO - Initializing the driver with 1 prefill engines and 1 generate engines in interleaved mode 2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up prefill thread 0. 2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up transfer thread 0. 2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up generate thread 0. 2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up detokenize thread 0. 2025-03-02 02:15:07,685 - JetstreamLogger - INFO - Driver initialized. ... ... ... INFO: Started server process [7] INFO: Waiting for application startup. INFO: Application startup complete. INFO: Uvicorn running on http://0.0.0.0:9999 (Press CTRL+C to quit)
Connettiti al server del modello
Puoi accedere al deployment di JetStream Pathways tramite il servizio ClusterIP di GKE. Il servizio ClusterIP è raggiungibile solo dall'interno del cluster. Pertanto, per accedere al servizio dall'esterno del cluster, stabilisci una sessione di port forwarding eseguendo questo comando:
kubectl port-forward pod/${HEAD_POD} 8000:8000
Verifica di poter accedere al server HTTP JetStream aprendo un nuovo terminale ed eseguendo questo comando:
curl --request POST \
--header "Content-type: application/json" \
-s \
localhost:8000/generate \
--data \
'{
"prompt": "What are the top 5 programming languages",
"max_tokens": 200
}'
Il completamento della richiesta iniziale può richiedere diversi secondi a causa del riscaldamento del modello. L'output dovrebbe essere simile al seguente:
{
"response": " used in software development?\nThe top 5 programming languages used in software development are:\n\n1. Java: Java is a popular programming language used for developing enterprise-level applications, Android apps, and web applications. Its platform independence and ability to run on any device that has a Java Virtual Machine (JVM) installed make it a favorite among developers.\n2. Python: Python is a versatile language that is widely used in software development, data analysis, artificial intelligence, and machine learning. Its simplicity, readability, and ease of use make it a popular choice among developers.\n3. JavaScript: JavaScript is a widely used programming language for web development, allowing developers to create interactive client-side functionality for web applications. It is also used for server-side programming, desktop and mobile application development, and game development.\n4. C++: C++ is a high-performance programming language used for developing operating systems, games, and other high-performance applications."
}
Passaggi successivi
- Carichi di lavoro batch con percorsi
- Modalità interattiva di Pathways
- Portare i carichi di lavoro JAX su Pathways
- Formazione resiliente con Pathways
- Percorsi di risoluzione dei problemi