L'inférence multihôte est une méthode d'exécution de l'inférence de modèle qui distribue le modèle sur plusieurs hôtes d'accélérateurs. Cela permet l'inférence de grands modèles qui ne tiennent pas sur un seul hôte. Les pathways peuvent être déployés pour les cas d'utilisation d'inférence multihôte par lot et en temps réel.
Avant de commencer
Vérifiez que vous disposez bien des éléments suivants :
- Création d'un cluster GKE utilisant des puces Trillium (v6e-16)
- Outils Kubernetes installés
- Activer l'API TPU
- Activer l'API Google Kubernetes Engine
Exécuter l'inférence par lot à l'aide de JetStream
JetStream est un moteur optimisé pour le débit et la mémoire pour l'inférence de grands modèles de langage (LLM) sur les appareils XLA, principalement les Tensor Processing Units (TPU) écrits en JAX.
Vous pouvez utiliser une image Docker JetStream prédéfinie pour exécuter une charge de travail d'inférence par lot, comme indiqué dans le fichier YAML suivant. Ce conteneur est créé à partir du projet OSS JetStream.
Pour en savoir plus sur les indicateurs MaxText-JetStream, consultez Indicateurs du serveur JetStream MaxText.
L'exemple suivant utilise des puces Trillium (v6e-16) pour charger le point de contrôle Llama3.1-405b int8 et effectuer une inférence sur celui-ci. Dans cet exemple, nous partons du principe que vous disposez déjà d'un cluster GKE avec au moins un pool de nœuds v6e-16.
Démarrer le serveur de modèle et Pathways
- Obtenez les identifiants du cluster et ajoutez-les à votre contexte kubectl local.
gcloud container clusters get-credentials $CLUSTER \ --zone=$ZONE \ --project=$PROJECT \ && kubectl config set-context --current --namespace=default
- Déployez l'API LeaderWorkerSet (LWS).
VERSION=v0.4.0 kubectl apply --server-side -f "https://github.com/kubernetes-sigs/lws/releases/download/${VERSION}/manifests.yaml"
- Copiez et collez le code YAML suivant dans un fichier nommé
pathways-job.yaml: Ce code YAML a été optimisé pour la forme de tranchev6e-16. Pour savoir comment convertir un point de contrôle Meta en point de contrôle compatible avec JAX, suivez le guide de création de points de contrôle dans Créer des points de contrôle d'inférence. Par exemple, les instructions pour Llama3.1-405B sont fournies ici : Conversion de point de contrôle pour Llama3.1-405B. Remplacez les éléments suivants :apiVersion: leaderworkerset.x-k8s.io/v1 kind: LeaderWorkerSet metadata: name: jetstream-pathways annotations: leaderworkerset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool spec: replicas: 1 leaderWorkerTemplate: leaderTemplate: metadata: labels: app: jetstream-pathways spec: nodeSelector: cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 4x4 tolerations: - key: "google.com/tpu" operator: "Exists" effect: "NoSchedule" containers: - name: pathways-proxy image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest args: - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677 - --server_port=38681 - --gcs_scratch_location=gs://cloud-pathways-staging/tmp imagePullPolicy: Always ports: - containerPort: 38681 - name: pathways-rm env: - name: HOST_ADDRESS value: "$(LWS_LEADER_ADDRESS)" - name: TPU_SKIP_MDS_QUERY value: "true" image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest args: - --server_port=38677 - --gcs_scratch_location=gs://cloud-pathways-staging/tmp - --node_type=resource_manager - --instance_count=1 - --instance_type=tpuv6e:TPU_TOPOLOGY # Example: 4x4 - --temporary_flags_for_debugging=temporary_flag_for_debugging_worker_expected_tpu_chip_config=megachip_tccontrol imagePullPolicy: Always ports: - containerPort: 38677 - name: jax-tpu image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-pathways:v0.2.0 # Optimized settings used to serve Llama3.1-405b. args: - MaxText/configs/v5e/inference/llama3_405b_v5e-64.yml - model_name=llama3.1-405b - load_parameters_path=GCS_CHECKPOINT_PATH - max_prefill_predict_length=1024 - max_target_length=2048 - async_checkpointing=false - steps=1 - ici_fsdp_parallelism=1 - ici_autoregressive_parallelism=2 - ici_tensor_parallelism=8 - scan_layers=false - weight_dtype=bfloat16 - per_device_batch_size=10 - enable_single_controller=true - quantization=int8 - quantize_kvcache=true - checkpoint_is_quantized=true - enable_model_warmup=true imagePullPolicy: Always ports: - containerPort: 9000 startupProbe: httpGet: path: /healthcheck port: 8000 scheme: HTTP periodSeconds: 1 initialDelaySeconds: 900 failureThreshold: 10000 livenessProbe: httpGet: path: /healthcheck port: 8000 scheme: HTTP periodSeconds: 60 failureThreshold: 10 readinessProbe: httpGet: path: /healthcheck port: 8000 scheme: HTTP periodSeconds: 60 failureThreshold: 10 - name: jetstream-http image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.3 imagePullPolicy: Always ports: - containerPort: 8000 # The size variable defines the number of worker nodes to be created. # It must be equal to the number of hosts + 1 (for the leader node). size: 5 workerTemplate: spec: nodeSelector: cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 4x4 tolerations: - key: "google.com/tpu" operator: "Exists" effect: "NoSchedule" containers: - name: worker args: - --server_port=38679 - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677 - --gcs_scratch_location=gs://cloud-pathways-staging/tmp image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest imagePullPolicy: Always ports: - containerPort: 38679 resources: limits: google.com/tpu: "4" --- apiVersion: v1 kind: Service metadata: name: jetstream-svc spec: selector: app: jetstream-pathways ports: - protocol: TCP name: jetstream-http port: 8000 targetPort: 8000
TPU_ACCELERATOR_TYPE: type d'accélérateur TPU. Exemple :tpu-v6e-sliceTPU_TOPOLOGY: topologie du TPU. Exemple :2x4GCS_CHECKPOINT_PATH: chemin d'accès GCS au point de contrôle.
- Consultez les journaux Kubernetes pour voir si le serveur de modèle JetStream est prêt :
La charge de travail a été nommée "jetstream-pathways" dans le fichier YAML précédent, et "0"
est le nœud principal.
Le résultat ressemble à ce qui suit, ce qui indique que le serveur de modèle JetStream est prêt à traiter les requêtes :kubectl logs -f jetstream-pathways-0 -c jax-tpu
2025-03-02 02:15:07,682 - JetstreamLogger - INFO - Initializing the driver with 1 prefill engines and 1 generate engines in interleaved mode 2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up prefill thread 0. 2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up transfer thread 0. 2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up generate thread 0. 2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up detokenize thread 0. 2025-03-02 02:15:07,685 - JetstreamLogger - INFO - Driver initialized. ... ... ... INFO: Started server process [7] INFO: Waiting for application startup. INFO: Application startup complete. INFO: Uvicorn running on http://0.0.0.0:9999 (Press CTRL+C to quit)
Se connecter au serveur de modèle
Vous pouvez accéder au déploiement JetStream Pathways à l'aide du service ClusterIP de GKE. Le service ClusterIP n'est accessible qu'à partir du cluster. Par conséquent, pour accéder au service en dehors du cluster, vous devez d'abord établir une session de transfert de port en exécutant la commande suivante :
kubectl port-forward pod/${HEAD_POD} 8000:8000
Vérifiez que vous pouvez accéder au serveur HTTP JetStream en ouvrant un nouveau terminal et en exécutant la commande suivante :
curl --request POST \
--header "Content-type: application/json" \
-s \
localhost:8000/generate \
--data \
'{
"prompt": "What are the top 5 programming languages",
"max_tokens": 200
}'
L'exécution de la requête initiale peut prendre plusieurs secondes en raison de l'échauffement du modèle. La sortie devrait ressembler à ce qui suit :
{
"response": " for web development?\nThe top 5 programming languages for web development are:\n1. **JavaScript**: JavaScript is the most popular language for web development, used by over 90% of websites for client-side scripting. It's also popular for server-side programming with technologies like Node.js.\n2. **HTML/CSS**: HTML (Hypertext Markup Language) and CSS (Cascading Style Sheets) are not programming languages, but are essential for building websites. HTML is used for structuring content, while CSS is used for styling and layout.\n3. **Python**: Python is a popular language for web development, especially with frameworks like Django and Flask. It's known for its simplicity, flexibility, and large community of developers.\n4. **Java**: Java is a popular language for building enterprise-level web applications, especially with frameworks like Spring and Hibernate. It's known for its platform independence, strong security features, and large community of developers.\n5. **PHP**: PHP is a mature language for web"
}
Inférence désagrégée
Le service désagrégé est une technique permettant d'exécuter de grands modèles de langage (LLM) qui sépare les étapes de préremplissage et de décodage dans différents processus, potentiellement sur différentes machines. Cela permet une meilleure utilisation des ressources et peut améliorer les performances et l'efficacité, en particulier pour les grands modèles.
- Préremplissage : cette étape traite l'invite d'entrée et génère une représentation intermédiaire (comme un cache clé-valeur). Elle nécessite souvent beaucoup de calculs.
- Décodage : cette étape génère les jetons de sortie, un par un, à l'aide de la représentation de préremplissage. Elle est généralement limitée par la bande passante de la mémoire.
En séparant ces étapes, la diffusion désagrégée permet au préremplissage et au décodage de s'exécuter en parallèle, ce qui améliore le débit et la latence.
Pour activer le service désagrégé, modifiez le code YAML suivant afin d'utiliser deux tranches v6e-8 : une pour le préremplissage et l'autre pour la génération. Avant de continuer, assurez-vous que votre cluster GKE comporte au moins deux pools de nœuds configurés avec cette topologie v6e-8.
Pour des performances optimales, des indicateurs XLA spécifiques ont été configurés.
Créez un point de contrôle llama2-70b en suivant le même processus que pour la création du point de contrôle llama3.1-405b, décrit dans la section précédente.
- Pour lancer le serveur JetStream en mode désagrégé à l'aide de Pathways, copiez et collez le code YAML suivant dans un fichier nommé
pathways-job.yaml: Remplacez les éléments suivants :apiVersion: leaderworkerset.x-k8s.io/v1 kind: LeaderWorkerSet metadata: name: jetstream-pathways annotations: leaderworkerset.sigs.k8s.io/subgroup-exclusive-topology: cloud.google.com/gke-nodepool spec: replicas: 1 leaderWorkerTemplate: subGroupPolicy: subGroupSize: 2 leaderTemplate: metadata: labels: app: jetstream-pathways spec: nodeSelector: cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 2x4 tolerations: - key: "google.com/tpu" operator: "Exists" effect: "NoSchedule" containers: - name: pathways-proxy image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest args: - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677 - --server_port=38681 - --gcs_scratch_location=gs://cloud-pathways-staging/tmp imagePullPolicy: Always ports: - containerPort: 38681 - name: pathways-rm env: - name: HOST_ADDRESS value: "$(LWS_LEADER_ADDRESS)" - name: TPU_SKIP_MDS_QUERY value: "true" image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest args: - --server_port=38677 - --gcs_scratch_location=gs://cloud-pathways-staging/tmp - --node_type=resource_manager - --instance_count=2 - --instance_type=tpuv6e:TPU_TOPOLOGY # Example: 2x4 imagePullPolicy: Always ports: - containerPort: 38677 - name: jax-tpu image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-pathways:v0.2.0 # Optimized settings used to serve Llama2-70b. args: - MaxText/configs/base.yml - tokenizer_path=assets/tokenizer.llama2 - load_parameters_path=GCS_CHECKPOINT_PATH - max_prefill_predict_length=1024 - max_target_length=2048 - model_name=llama2-70b - ici_fsdp_parallelism=1 - ici_autoregressive_parallelism=1 - ici_tensor_parallelism=-1 - scan_layers=false - weight_dtype=bfloat16 - per_device_batch_size=1 - checkpoint_is_quantized=true - quantization=int8 - quantize_kvcache=true - compute_axis_order=0,2,1,3 - ar_cache_axis_order=0,2,1,3 - stack_prefill_result_cache=True # Specify disaggregated mode to run Jetstream - inference_server=ExperimentalMaxtextDisaggregatedServer_8 - inference_benchmark_test=True - enable_model_warmup=True env: - name: LOG_LEVEL value: "INFO" imagePullPolicy: Always securityContext: capabilities: add: ["SYS_PTRACE", "NET_ADMIN", "SYS_TIME"] ports: - containerPort: 9000 startupProbe: httpGet: path: /healthcheck port: 8000 scheme: HTTP periodSeconds: 1 initialDelaySeconds: 240 failureThreshold: 10000 livenessProbe: httpGet: path: /healthcheck port: 8000 scheme: HTTP periodSeconds: 60 failureThreshold: 100 readinessProbe: httpGet: path: /healthcheck port: 8000 scheme: HTTP periodSeconds: 60 failureThreshold: 100 - name: jetstream-http image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.3 imagePullPolicy: Always ports: - containerPort: 8000 # The size variable defines the number of worker nodes to be created. # It must be equal to the number of hosts + 1 (for the leader node). size: 5 workerTemplate: spec: nodeSelector: cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 2x4 containers: - name: worker args: - --server_port=38679 - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677 - --gcs_scratch_location=gs://cloud-pathways-staging/tmp image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest imagePullPolicy: Always ports: - containerPort: 38679 resources: limits: google.com/tpu: "4" --- apiVersion: v1 kind: Service metadata: name: jetstream-svc spec: selector: app: jetstream-pathways ports: - protocol: TCP name: jetstream-http port: 8000 targetPort: 8000
TPU_ACCELERATOR_TYPE: type d'accélérateur TPU. Exemple :tpu-v6e-sliceTPU_TOPOLOGY: topologie du TPU. Exemple :2x4GCS_CHECKPOINT_PATH: chemin d'accès GCS au point de contrôle.
- Appliquez ce fichier YAML. Le serveur de modèle mettra un certain temps à restaurer le point de contrôle. Pour le modèle 70B, cela peut prendre environ deux minutes.
kubectl apply -f pathways-job.yaml
- Consultez les journaux Kubernetes pour voir si le serveur de modèle JetStream est prêt :
Un résultat semblable à celui qui suit s'affiche, indiquant que le serveur de modèle JetStream est prêt à traiter les requêtes :kubectl logs -f jetstream-pathways-0 -c jax-tpu
2025-03-02 02:15:07,682 - JetstreamLogger - INFO - Initializing the driver with 1 prefill engines and 1 generate engines in interleaved mode 2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up prefill thread 0. 2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up transfer thread 0. 2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up generate thread 0. 2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up detokenize thread 0. 2025-03-02 02:15:07,685 - JetstreamLogger - INFO - Driver initialized. ... ... ... INFO: Started server process [7] INFO: Waiting for application startup. INFO: Application startup complete. INFO: Uvicorn running on http://0.0.0.0:9999 (Press CTRL+C to quit)
Se connecter au serveur de modèle
Vous pouvez accéder au déploiement JetStream Pathways via le service ClusterIP de GKE. Le service ClusterIP n'est accessible qu'à partir du cluster. Par conséquent, pour accéder au service en dehors du cluster, établissez une session de transfert de port en exécutant la commande suivante :
kubectl port-forward pod/${HEAD_POD} 8000:8000
Vérifiez que vous pouvez accéder au serveur HTTP JetStream en ouvrant un nouveau terminal et en exécutant la commande suivante :
curl --request POST \
--header "Content-type: application/json" \
-s \
localhost:8000/generate \
--data \
'{
"prompt": "What are the top 5 programming languages",
"max_tokens": 200
}'
L'exécution de la requête initiale peut prendre plusieurs secondes en raison de l'échauffement du modèle. La sortie devrait ressembler à ce qui suit :
{
"response": " used in software development?\nThe top 5 programming languages used in software development are:\n\n1. Java: Java is a popular programming language used for developing enterprise-level applications, Android apps, and web applications. Its platform independence and ability to run on any device that has a Java Virtual Machine (JVM) installed make it a favorite among developers.\n2. Python: Python is a versatile language that is widely used in software development, data analysis, artificial intelligence, and machine learning. Its simplicity, readability, and ease of use make it a popular choice among developers.\n3. JavaScript: JavaScript is a widely used programming language for web development, allowing developers to create interactive client-side functionality for web applications. It is also used for server-side programming, desktop and mobile application development, and game development.\n4. C++: C++ is a high-performance programming language used for developing operating systems, games, and other high-performance applications."
}
Étapes suivantes
- Charges de travail par lot avec Pathways
- Mode interactif des parcours
- Transférer des charges de travail JAX vers Pathways
- Entraînement résilient avec Pathways
- Parcours de dépannage