Effectuer une inférence multihôte à l'aide de Pathways

L'inférence multihôte est une méthode d'exécution de l'inférence de modèle qui distribue le modèle sur plusieurs hôtes d'accélérateurs. Cela permet l'inférence de grands modèles qui ne tiennent pas sur un seul hôte. Les pathways peuvent être déployés pour les cas d'utilisation d'inférence multihôte par lot et en temps réel.

Avant de commencer

Vérifiez que vous disposez bien des éléments suivants :

Exécuter l'inférence par lot à l'aide de JetStream

JetStream est un moteur optimisé pour le débit et la mémoire pour l'inférence de grands modèles de langage (LLM) sur les appareils XLA, principalement les Tensor Processing Units (TPU) écrits en JAX.

Vous pouvez utiliser une image Docker JetStream prédéfinie pour exécuter une charge de travail d'inférence par lot, comme indiqué dans le fichier YAML suivant. Ce conteneur est créé à partir du projet OSS JetStream. Pour en savoir plus sur les indicateurs MaxText-JetStream, consultez Indicateurs du serveur JetStream MaxText. L'exemple suivant utilise des puces Trillium (v6e-16) pour charger le point de contrôle Llama3.1-405b int8 et effectuer une inférence sur celui-ci. Dans cet exemple, nous partons du principe que vous disposez déjà d'un cluster GKE avec au moins un pool de nœuds v6e-16.

Démarrer le serveur de modèle et Pathways

  1. Obtenez les identifiants du cluster et ajoutez-les à votre contexte kubectl local.
          gcloud container clusters get-credentials $CLUSTER \
          --zone=$ZONE \
          --project=$PROJECT \
          && kubectl config set-context --current --namespace=default
        
  2. Déployez l'API LeaderWorkerSet (LWS).
          VERSION=v0.4.0
          kubectl apply --server-side -f "https://github.com/kubernetes-sigs/lws/releases/download/${VERSION}/manifests.yaml"
        
  3. Copiez et collez le code YAML suivant dans un fichier nommé pathways-job.yaml : Ce code YAML a été optimisé pour la forme de tranche v6e-16. Pour savoir comment convertir un point de contrôle Meta en point de contrôle compatible avec JAX, suivez le guide de création de points de contrôle dans Créer des points de contrôle d'inférence. Par exemple, les instructions pour Llama3.1-405B sont fournies ici : Conversion de point de contrôle pour Llama3.1-405B.
        apiVersion: leaderworkerset.x-k8s.io/v1
        kind: LeaderWorkerSet
        metadata:
          name: jetstream-pathways
          annotations:
            leaderworkerset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
        spec:
          replicas: 1
          leaderWorkerTemplate:
            leaderTemplate:
              metadata:
                labels: 
                  app: jetstream-pathways
              spec:
                nodeSelector:
                  cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice
                  cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 4x4
                tolerations:
                - key: "google.com/tpu"
                  operator: "Exists"
                  effect: "NoSchedule"
                containers:
                - name: pathways-proxy
                  image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest
                  args:
                  - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677
                  - --server_port=38681
                  - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
                  imagePullPolicy: Always
                  ports:
                  - containerPort: 38681
                - name: pathways-rm
                  env:       
                  - name: HOST_ADDRESS
                    value: "$(LWS_LEADER_ADDRESS)"
                  - name: TPU_SKIP_MDS_QUERY
                    value: "true"
                  image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest
                  args:
                  - --server_port=38677
                  - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
                  - --node_type=resource_manager
                  - --instance_count=1
                  - --instance_type=tpuv6e:TPU_TOPOLOGY # Example: 4x4
                  - --temporary_flags_for_debugging=temporary_flag_for_debugging_worker_expected_tpu_chip_config=megachip_tccontrol
                  imagePullPolicy: Always
                  ports:
                  - containerPort: 38677
                - name: jax-tpu
                  image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-pathways:v0.2.0
                  # Optimized settings used to serve Llama3.1-405b.
                  args:
                  - MaxText/configs/v5e/inference/llama3_405b_v5e-64.yml
                  - model_name=llama3.1-405b
                  - load_parameters_path=GCS_CHECKPOINT_PATH
                  - max_prefill_predict_length=1024
                  - max_target_length=2048
                  - async_checkpointing=false
                  - steps=1
                  - ici_fsdp_parallelism=1
                  - ici_autoregressive_parallelism=2
                  - ici_tensor_parallelism=8
                  - scan_layers=false
                  - weight_dtype=bfloat16
                  - per_device_batch_size=10
                  - enable_single_controller=true
                  - quantization=int8
                  - quantize_kvcache=true
                  - checkpoint_is_quantized=true
                  - enable_model_warmup=true
                  imagePullPolicy: Always
                  ports: 
                  - containerPort: 9000
                  startupProbe:
                    httpGet:
                      path: /healthcheck
                      port: 8000
                      scheme: HTTP
                    periodSeconds: 1
                    initialDelaySeconds: 900
                    failureThreshold: 10000
                  livenessProbe:
                    httpGet:
                      path: /healthcheck
                      port: 8000
                      scheme: HTTP
                    periodSeconds: 60
                    failureThreshold: 10
                  readinessProbe:
                    httpGet:
                      path: /healthcheck
                      port: 8000
                      scheme: HTTP
                    periodSeconds: 60
                    failureThreshold: 10
                - name: jetstream-http
                  image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.3
                  imagePullPolicy: Always
                  ports:
                  - containerPort: 8000
            # The size variable defines the number of worker nodes to be created.
            # It must be equal to the number of hosts + 1 (for the leader node).
            size: 5
            workerTemplate:
              spec:
                nodeSelector:
                  cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice
                  cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 4x4
                tolerations:
                - key: "google.com/tpu"
                  operator: "Exists"
                  effect: "NoSchedule"
                containers:
                - name: worker
                  args:
                  - --server_port=38679
                  - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677
                  - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
                  image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest
                  imagePullPolicy: Always
                  ports:
                  - containerPort: 38679
                  resources:
                    limits:
                      google.com/tpu: "4"
        --- 
        apiVersion: v1
        kind: Service
        metadata:
          name: jetstream-svc
        spec:
          selector:
            app: jetstream-pathways
          ports:
          - protocol: TCP
            name: jetstream-http
            port: 8000
            targetPort: 8000
        
    Remplacez les éléments suivants :
    • TPU_ACCELERATOR_TYPE : type d'accélérateur TPU. Exemple :tpu-v6e-slice
    • TPU_TOPOLOGY : topologie du TPU. Exemple :2x4
    • GCS_CHECKPOINT_PATH : chemin d'accès GCS au point de contrôle.
    Appliquez ce fichier YAML. Attendez que PathwaysJob soit planifié. Une fois la planification effectuée, le serveur de modèle peut mettre un certain temps à restaurer le point de contrôle. Pour le modèle 405B, cela prend environ sept minutes.
  4. Consultez les journaux Kubernetes pour voir si le serveur de modèle JetStream est prêt : La charge de travail a été nommée "jetstream-pathways" dans le fichier YAML précédent, et "0" est le nœud principal.
          kubectl logs -f jetstream-pathways-0 -c jax-tpu
          
    Le résultat ressemble à ce qui suit, ce qui indique que le serveur de modèle JetStream est prêt à traiter les requêtes :
        2025-03-02 02:15:07,682 - JetstreamLogger - INFO - Initializing the driver with 1 prefill engines and 1 generate engines in interleaved mode
        2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up prefill thread 0.
        2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up transfer thread 0.
        2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up generate thread 0.
        2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up detokenize thread 0.
        2025-03-02 02:15:07,685 - JetstreamLogger - INFO - Driver initialized.
        ...
        ...
        ...
        INFO:     Started server process [7]
        INFO:     Waiting for application startup.
        INFO:     Application startup complete.
        INFO:     Uvicorn running on http://0.0.0.0:9999 (Press CTRL+C to quit)
        

Se connecter au serveur de modèle

Vous pouvez accéder au déploiement JetStream Pathways à l'aide du service ClusterIP de GKE. Le service ClusterIP n'est accessible qu'à partir du cluster. Par conséquent, pour accéder au service en dehors du cluster, vous devez d'abord établir une session de transfert de port en exécutant la commande suivante :

kubectl port-forward pod/${HEAD_POD} 8000:8000

Vérifiez que vous pouvez accéder au serveur HTTP JetStream en ouvrant un nouveau terminal et en exécutant la commande suivante :

curl --request POST \
--header "Content-type: application/json" \
-s \
localhost:8000/generate \
--data \
'{
    "prompt": "What are the top 5 programming languages",
    "max_tokens": 200
}'

L'exécution de la requête initiale peut prendre plusieurs secondes en raison de l'échauffement du modèle. La sortie devrait ressembler à ce qui suit :

{
    "response": " for web development?\nThe top 5 programming languages for web development are:\n1. **JavaScript**: JavaScript is the most popular language for web development, used by over 90% of websites for client-side scripting. It's also popular for server-side programming with technologies like Node.js.\n2. **HTML/CSS**: HTML (Hypertext Markup Language) and CSS (Cascading Style Sheets) are not programming languages, but are essential for building websites. HTML is used for structuring content, while CSS is used for styling and layout.\n3. **Python**: Python is a popular language for web development, especially with frameworks like Django and Flask. It's known for its simplicity, flexibility, and large community of developers.\n4. **Java**: Java is a popular language for building enterprise-level web applications, especially with frameworks like Spring and Hibernate. It's known for its platform independence, strong security features, and large community of developers.\n5. **PHP**: PHP is a mature language for web"
}

Inférence désagrégée

Le service désagrégé est une technique permettant d'exécuter de grands modèles de langage (LLM) qui sépare les étapes de préremplissage et de décodage dans différents processus, potentiellement sur différentes machines. Cela permet une meilleure utilisation des ressources et peut améliorer les performances et l'efficacité, en particulier pour les grands modèles.

  • Préremplissage : cette étape traite l'invite d'entrée et génère une représentation intermédiaire (comme un cache clé-valeur). Elle nécessite souvent beaucoup de calculs.
  • Décodage : cette étape génère les jetons de sortie, un par un, à l'aide de la représentation de préremplissage. Elle est généralement limitée par la bande passante de la mémoire.

En séparant ces étapes, la diffusion désagrégée permet au préremplissage et au décodage de s'exécuter en parallèle, ce qui améliore le débit et la latence.

Pour activer le service désagrégé, modifiez le code YAML suivant afin d'utiliser deux tranches v6e-8 : une pour le préremplissage et l'autre pour la génération. Avant de continuer, assurez-vous que votre cluster GKE comporte au moins deux pools de nœuds configurés avec cette topologie v6e-8. Pour des performances optimales, des indicateurs XLA spécifiques ont été configurés.

Créez un point de contrôle llama2-70b en suivant le même processus que pour la création du point de contrôle llama3.1-405b, décrit dans la section précédente.

  1. Pour lancer le serveur JetStream en mode désagrégé à l'aide de Pathways, copiez et collez le code YAML suivant dans un fichier nommé pathways-job.yaml :
    apiVersion: leaderworkerset.x-k8s.io/v1
    kind: LeaderWorkerSet
    metadata:
      name: jetstream-pathways
      annotations:
        leaderworkerset.sigs.k8s.io/subgroup-exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      replicas: 1
      leaderWorkerTemplate:
        subGroupPolicy:
          subGroupSize: 2
        leaderTemplate:
          metadata:
            labels:
              app: jetstream-pathways
          spec:
            nodeSelector:
              cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice
              cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 2x4
            tolerations:
            - key: "google.com/tpu"
              operator: "Exists"
              effect: "NoSchedule"
            containers:
            - name: pathways-proxy
              image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest
              args:
              - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677
              - --server_port=38681
              - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
              imagePullPolicy: Always
              ports:
              - containerPort: 38681
            - name: pathways-rm
              env:       
              - name: HOST_ADDRESS
                value: "$(LWS_LEADER_ADDRESS)"
              - name: TPU_SKIP_MDS_QUERY
                value: "true"
              image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest
              args:
              - --server_port=38677
              - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
              - --node_type=resource_manager
              - --instance_count=2
              - --instance_type=tpuv6e:TPU_TOPOLOGY # Example: 2x4
              imagePullPolicy: Always
              ports:
              - containerPort: 38677
            - name: jax-tpu
              image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-pathways:v0.2.0
              # Optimized settings used to serve Llama2-70b.
              args:
              - MaxText/configs/base.yml
              - tokenizer_path=assets/tokenizer.llama2
              - load_parameters_path=GCS_CHECKPOINT_PATH
              - max_prefill_predict_length=1024
              - max_target_length=2048
              - model_name=llama2-70b
              - ici_fsdp_parallelism=1
              - ici_autoregressive_parallelism=1
              - ici_tensor_parallelism=-1
              - scan_layers=false
              - weight_dtype=bfloat16
              - per_device_batch_size=1
              - checkpoint_is_quantized=true 
              - quantization=int8
              - quantize_kvcache=true
              - compute_axis_order=0,2,1,3
              - ar_cache_axis_order=0,2,1,3
              - stack_prefill_result_cache=True
              # Specify disaggregated mode to run Jetstream
              - inference_server=ExperimentalMaxtextDisaggregatedServer_8
              - inference_benchmark_test=True
              - enable_model_warmup=True
              env:
              - name: LOG_LEVEL
                value: "INFO"
              imagePullPolicy: Always
              securityContext:
                capabilities:
                  add: ["SYS_PTRACE", "NET_ADMIN", "SYS_TIME"]
              ports: 
              - containerPort: 9000
              startupProbe:
                httpGet:
                  path: /healthcheck
                  port: 8000
                  scheme: HTTP
                periodSeconds: 1
                initialDelaySeconds: 240
                failureThreshold: 10000
              livenessProbe:
                httpGet:
                  path: /healthcheck
                  port: 8000
                  scheme: HTTP
                periodSeconds: 60
                failureThreshold: 100
              readinessProbe:
                httpGet:
                  path: /healthcheck
                  port: 8000
                  scheme: HTTP
                periodSeconds: 60
                failureThreshold: 100
            - name: jetstream-http
              image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.3
              imagePullPolicy: Always
              ports:
              - containerPort: 8000
        # The size variable defines the number of worker nodes to be created.
        # It must be equal to the number of hosts + 1 (for the leader node).
        size: 5
        workerTemplate:
          spec:
            nodeSelector:
              cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice
              cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 2x4
            containers:
            - name: worker
              args:
              - --server_port=38679
              - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677
              - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
              image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest
              imagePullPolicy: Always
              ports:
              - containerPort: 38679
              resources:
                limits:
                  google.com/tpu: "4"
    --- 
    apiVersion: v1
    kind: Service
    metadata:
      name: jetstream-svc
    spec:
      selector:
        app: jetstream-pathways
      ports:
      - protocol: TCP
        name: jetstream-http
        port: 8000
        targetPort: 8000
      
    Remplacez les éléments suivants :
    • TPU_ACCELERATOR_TYPE : type d'accélérateur TPU. Exemple :tpu-v6e-slice
    • TPU_TOPOLOGY : topologie du TPU. Exemple :2x4
    • GCS_CHECKPOINT_PATH : chemin d'accès GCS au point de contrôle.
  2. Appliquez ce fichier YAML. Le serveur de modèle mettra un certain temps à restaurer le point de contrôle. Pour le modèle 70B, cela peut prendre environ deux minutes.
      kubectl apply -f pathways-job.yaml
          
  3. Consultez les journaux Kubernetes pour voir si le serveur de modèle JetStream est prêt :
        kubectl logs -f jetstream-pathways-0 -c jax-tpu
        
    Un résultat semblable à celui qui suit s'affiche, indiquant que le serveur de modèle JetStream est prêt à traiter les requêtes :
        2025-03-02 02:15:07,682 - JetstreamLogger - INFO - Initializing the driver with 1 prefill engines and 1 generate engines in interleaved mode
        2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up prefill thread 0.
        2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up transfer thread 0.
        2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up generate thread 0.
        2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up detokenize thread 0.
        2025-03-02 02:15:07,685 - JetstreamLogger - INFO - Driver initialized.
        ...
        ...
        ...
        INFO:     Started server process [7]
        INFO:     Waiting for application startup.
        INFO:     Application startup complete.
        INFO:     Uvicorn running on http://0.0.0.0:9999 (Press CTRL+C to quit)
      

Se connecter au serveur de modèle

Vous pouvez accéder au déploiement JetStream Pathways via le service ClusterIP de GKE. Le service ClusterIP n'est accessible qu'à partir du cluster. Par conséquent, pour accéder au service en dehors du cluster, établissez une session de transfert de port en exécutant la commande suivante :

kubectl port-forward pod/${HEAD_POD} 8000:8000

Vérifiez que vous pouvez accéder au serveur HTTP JetStream en ouvrant un nouveau terminal et en exécutant la commande suivante :

curl --request POST \
--header "Content-type: application/json" \
-s \
localhost:8000/generate \
--data \
'{
    "prompt": "What are the top 5 programming languages",
    "max_tokens": 200
}'

L'exécution de la requête initiale peut prendre plusieurs secondes en raison de l'échauffement du modèle. La sortie devrait ressembler à ce qui suit :

{
    "response": " used in software development?\nThe top 5 programming languages used in software development are:\n\n1. Java: Java is a popular programming language used for developing enterprise-level applications, Android apps, and web applications. Its platform independence and ability to run on any device that has a Java Virtual Machine (JVM) installed make it a favorite among developers.\n2. Python: Python is a versatile language that is widely used in software development, data analysis, artificial intelligence, and machine learning. Its simplicity, readability, and ease of use make it a popular choice among developers.\n3. JavaScript: JavaScript is a widely used programming language for web development, allowing developers to create interactive client-side functionality for web applications. It is also used for server-side programming, desktop and mobile application development, and game development.\n4. C++: C++ is a high-performance programming language used for developing operating systems, games, and other high-performance applications."
}

Étapes suivantes