Realiza la inferencia multihost con Pathways

La inferencia en varios hosts es un método para ejecutar la inferencia del modelo que distribuye el modelo en varios hosts de aceleradores. Esto permite la inferencia de modelos grandes que no caben en un solo host. Las rutas de acceso se pueden implementar para casos de uso de inferencia multihost por lotes y en tiempo real.

Antes de comenzar

Asegúrate de tener lo siguiente:

Ejecuta la inferencia por lotes con JetStream

JetStream es un motor optimizado para la capacidad de procesamiento y la memoria para la inferencia de modelos de lenguaje grandes (LLM) en dispositivos XLA, principalmente unidades de procesamiento tensorial (TPU) escritas en JAX.

Puedes usar una imagen de Docker de JetStream compilada previamente para ejecutar una carga de trabajo de inferencia por lotes, como se muestra en el siguiente YAML. Este contenedor se compila a partir del proyecto de OSS JetStream. Para obtener más información sobre las marcas de MaxText-JetStream, consulta Marcas del servidor de JetStream MaxText. En el siguiente ejemplo, se usan chips Trillium (v6e-16) para cargar el punto de control int8 de Llama3.1-405b y realizar la inferencia sobre él. En este ejemplo, se supone que ya tienes un clúster de GKE con al menos un grupo de nodos v6e-16.

Inicia el servidor de modelos y Pathways

  1. Obtén credenciales para el clúster y agrégalas a tu contexto de kubectl local.
          gcloud container clusters get-credentials $CLUSTER \
          --zone=$ZONE \
          --project=$PROJECT \
          && kubectl config set-context --current --namespace=default
        
  2. Implementa la API de LeaderWorkerSet (LWS).
          VERSION=v0.4.0
          kubectl apply --server-side -f "https://github.com/kubernetes-sigs/lws/releases/download/${VERSION}/manifests.yaml"
        
  3. Copia y pega el siguiente YAML en un archivo llamado pathways-job.yaml: Este YAML se optimizó para la forma de corte v6e-16. Para obtener más información sobre cómo convertir un punto de control de Meta en uno compatible con JAX, sigue la guía de creación de puntos de control en Cómo crear puntos de control de inferencia. Como ejemplo, aquí se proporcionan instrucciones para Llama3.1-405B Checkpoint conversion for Llama3.1-405B.
        apiVersion: leaderworkerset.x-k8s.io/v1
        kind: LeaderWorkerSet
        metadata:
          name: jetstream-pathways
          annotations:
            leaderworkerset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
        spec:
          replicas: 1
          leaderWorkerTemplate:
            leaderTemplate:
              metadata:
                labels: 
                  app: jetstream-pathways
              spec:
                nodeSelector:
                  cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice
                  cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 4x4
                tolerations:
                - key: "google.com/tpu"
                  operator: "Exists"
                  effect: "NoSchedule"
                containers:
                - name: pathways-proxy
                  image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest
                  args:
                  - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677
                  - --server_port=38681
                  - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
                  imagePullPolicy: Always
                  ports:
                  - containerPort: 38681
                - name: pathways-rm
                  env:       
                  - name: HOST_ADDRESS
                    value: "$(LWS_LEADER_ADDRESS)"
                  - name: TPU_SKIP_MDS_QUERY
                    value: "true"
                  image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest
                  args:
                  - --server_port=38677
                  - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
                  - --node_type=resource_manager
                  - --instance_count=1
                  - --instance_type=tpuv6e:TPU_TOPOLOGY # Example: 4x4
                  - --temporary_flags_for_debugging=temporary_flag_for_debugging_worker_expected_tpu_chip_config=megachip_tccontrol
                  imagePullPolicy: Always
                  ports:
                  - containerPort: 38677
                - name: jax-tpu
                  image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-pathways:v0.2.0
                  # Optimized settings used to serve Llama3.1-405b.
                  args:
                  - MaxText/configs/v5e/inference/llama3_405b_v5e-64.yml
                  - model_name=llama3.1-405b
                  - load_parameters_path=GCS_CHECKPOINT_PATH
                  - max_prefill_predict_length=1024
                  - max_target_length=2048
                  - async_checkpointing=false
                  - steps=1
                  - ici_fsdp_parallelism=1
                  - ici_autoregressive_parallelism=2
                  - ici_tensor_parallelism=8
                  - scan_layers=false
                  - weight_dtype=bfloat16
                  - per_device_batch_size=10
                  - enable_single_controller=true
                  - quantization=int8
                  - quantize_kvcache=true
                  - checkpoint_is_quantized=true
                  - enable_model_warmup=true
                  imagePullPolicy: Always
                  ports: 
                  - containerPort: 9000
                  startupProbe:
                    httpGet:
                      path: /healthcheck
                      port: 8000
                      scheme: HTTP
                    periodSeconds: 1
                    initialDelaySeconds: 900
                    failureThreshold: 10000
                  livenessProbe:
                    httpGet:
                      path: /healthcheck
                      port: 8000
                      scheme: HTTP
                    periodSeconds: 60
                    failureThreshold: 10
                  readinessProbe:
                    httpGet:
                      path: /healthcheck
                      port: 8000
                      scheme: HTTP
                    periodSeconds: 60
                    failureThreshold: 10
                - name: jetstream-http
                  image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.3
                  imagePullPolicy: Always
                  ports:
                  - containerPort: 8000
            # The size variable defines the number of worker nodes to be created.
            # It must be equal to the number of hosts + 1 (for the leader node).
            size: 5
            workerTemplate:
              spec:
                nodeSelector:
                  cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice
                  cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 4x4
                tolerations:
                - key: "google.com/tpu"
                  operator: "Exists"
                  effect: "NoSchedule"
                containers:
                - name: worker
                  args:
                  - --server_port=38679
                  - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677
                  - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
                  image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest
                  imagePullPolicy: Always
                  ports:
                  - containerPort: 38679
                  resources:
                    limits:
                      google.com/tpu: "4"
        --- 
        apiVersion: v1
        kind: Service
        metadata:
          name: jetstream-svc
        spec:
          selector:
            app: jetstream-pathways
          ports:
          - protocol: TCP
            name: jetstream-http
            port: 8000
            targetPort: 8000
        
    Reemplaza lo siguiente:
    • TPU_ACCELERATOR_TYPE: Es el tipo de acelerador de TPU. Por ejemplo, tpu-v6e-slice
    • TPU_TOPOLOGY: Es la topología de la TPU. Por ejemplo, 2x4
    • GCS_CHECKPOINT_PATH: Es la ruta de acceso de GCS al punto de control.
    Aplica este archivo YAML. Espera a que se programe el trabajo de Pathways. Una vez programado, es posible que el servidor del modelo tarde un tiempo en restablecer el punto de control. En el caso del modelo 405B, esto tarda alrededor de 7 minutos.
  4. Consulta los registros de Kubernetes para ver si el servidor del modelo de JetStream está listo: La carga de trabajo se llamó "jetstream-pathways" en el YAML anterior, y "0" es el nodo principal.
          kubectl logs -f jetstream-pathways-0 -c jax-tpu
          
    El resultado es similar al siguiente, que indica que el servidor del modelo de JetStream está listo para atender solicitudes:
        2025-03-02 02:15:07,682 - JetstreamLogger - INFO - Initializing the driver with 1 prefill engines and 1 generate engines in interleaved mode
        2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up prefill thread 0.
        2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up transfer thread 0.
        2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up generate thread 0.
        2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up detokenize thread 0.
        2025-03-02 02:15:07,685 - JetstreamLogger - INFO - Driver initialized.
        ...
        ...
        ...
        INFO:     Started server process [7]
        INFO:     Waiting for application startup.
        INFO:     Application startup complete.
        INFO:     Uvicorn running on http://0.0.0.0:9999 (Press CTRL+C to quit)
        

Conéctate al servidor del modelo

Puedes acceder a la implementación de JetStream Pathways con el servicio ClusterIP de GKE. Solo se puede acceder al servicio de ClusterIP desde el clúster. Por lo tanto, para acceder al servicio desde fuera del clúster, primero debes establecer una sesión de redirección de puertos ejecutando el siguiente comando:

kubectl port-forward pod/${HEAD_POD} 8000:8000

Para comprobar que puedes acceder al servidor HTTP de JetStream, abre una terminal nueva y ejecuta el siguiente comando:

curl --request POST \
--header "Content-type: application/json" \
-s \
localhost:8000/generate \
--data \
'{
    "prompt": "What are the top 5 programming languages",
    "max_tokens": 200
}'

La solicitud inicial puede tardar varios segundos en completarse debido a la preparación del modelo. El resultado debería ser similar al siguiente ejemplo:

{
    "response": " for web development?\nThe top 5 programming languages for web development are:\n1. **JavaScript**: JavaScript is the most popular language for web development, used by over 90% of websites for client-side scripting. It's also popular for server-side programming with technologies like Node.js.\n2. **HTML/CSS**: HTML (Hypertext Markup Language) and CSS (Cascading Style Sheets) are not programming languages, but are essential for building websites. HTML is used for structuring content, while CSS is used for styling and layout.\n3. **Python**: Python is a popular language for web development, especially with frameworks like Django and Flask. It's known for its simplicity, flexibility, and large community of developers.\n4. **Java**: Java is a popular language for building enterprise-level web applications, especially with frameworks like Spring and Hibernate. It's known for its platform independence, strong security features, and large community of developers.\n5. **PHP**: PHP is a mature language for web"
}

Inferencia desagregada

La publicación desagregada es una técnica para ejecutar modelos de lenguaje grandes (LLM) que separa las etapas de prefill y decodificación en diferentes procesos, posiblemente en diferentes máquinas. Esto permite un mejor uso de los recursos y puede generar mejoras en el rendimiento y la eficiencia, en especial para los modelos grandes.

  • Prefill: En esta etapa, se procesa la instrucción de entrada y se genera una representación intermedia (como una caché de pares clave-valor). A menudo, requiere mucha capacidad de procesamiento.
  • Decodificación: En esta etapa, se generan los tokens de salida, uno por uno, con la representación de precompletado. Por lo general, está limitado por el ancho de banda de la memoria.

Al separar estas etapas, la entrega desagregada permite que el precompletado y la decodificación se ejecuten en paralelo, lo que mejora la capacidad de procesamiento y la latencia.

Para habilitar la publicación desagregada, modifica el siguiente código YAML para utilizar dos segmentos de v6e-8: uno para el prellenado y otro para la generación. Antes de continuar, asegúrate de que tu clúster de GKE tenga al menos dos grupos de nodos configurados con esta topología v6e-8. Para un rendimiento óptimo, se configuraron indicadores específicos de XLA.

Crea un punto de control de llama2-70b siguiendo el mismo proceso que el de creación del punto de control de llama3.1-405b, que se detalla en la sección anterior.

  1. Para iniciar el servidor de JetStream en modo desagregado con Pathways, copia y pega el siguiente código YAML en un archivo llamado pathways-job.yaml:
    apiVersion: leaderworkerset.x-k8s.io/v1
    kind: LeaderWorkerSet
    metadata:
      name: jetstream-pathways
      annotations:
        leaderworkerset.sigs.k8s.io/subgroup-exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      replicas: 1
      leaderWorkerTemplate:
        subGroupPolicy:
          subGroupSize: 2
        leaderTemplate:
          metadata:
            labels:
              app: jetstream-pathways
          spec:
            nodeSelector:
              cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice
              cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 2x4
            tolerations:
            - key: "google.com/tpu"
              operator: "Exists"
              effect: "NoSchedule"
            containers:
            - name: pathways-proxy
              image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest
              args:
              - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677
              - --server_port=38681
              - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
              imagePullPolicy: Always
              ports:
              - containerPort: 38681
            - name: pathways-rm
              env:       
              - name: HOST_ADDRESS
                value: "$(LWS_LEADER_ADDRESS)"
              - name: TPU_SKIP_MDS_QUERY
                value: "true"
              image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest
              args:
              - --server_port=38677
              - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
              - --node_type=resource_manager
              - --instance_count=2
              - --instance_type=tpuv6e:TPU_TOPOLOGY # Example: 2x4
              imagePullPolicy: Always
              ports:
              - containerPort: 38677
            - name: jax-tpu
              image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-pathways:v0.2.0
              # Optimized settings used to serve Llama2-70b.
              args:
              - MaxText/configs/base.yml
              - tokenizer_path=assets/tokenizer.llama2
              - load_parameters_path=GCS_CHECKPOINT_PATH
              - max_prefill_predict_length=1024
              - max_target_length=2048
              - model_name=llama2-70b
              - ici_fsdp_parallelism=1
              - ici_autoregressive_parallelism=1
              - ici_tensor_parallelism=-1
              - scan_layers=false
              - weight_dtype=bfloat16
              - per_device_batch_size=1
              - checkpoint_is_quantized=true 
              - quantization=int8
              - quantize_kvcache=true
              - compute_axis_order=0,2,1,3
              - ar_cache_axis_order=0,2,1,3
              - stack_prefill_result_cache=True
              # Specify disaggregated mode to run Jetstream
              - inference_server=ExperimentalMaxtextDisaggregatedServer_8
              - inference_benchmark_test=True
              - enable_model_warmup=True
              env:
              - name: LOG_LEVEL
                value: "INFO"
              imagePullPolicy: Always
              securityContext:
                capabilities:
                  add: ["SYS_PTRACE", "NET_ADMIN", "SYS_TIME"]
              ports: 
              - containerPort: 9000
              startupProbe:
                httpGet:
                  path: /healthcheck
                  port: 8000
                  scheme: HTTP
                periodSeconds: 1
                initialDelaySeconds: 240
                failureThreshold: 10000
              livenessProbe:
                httpGet:
                  path: /healthcheck
                  port: 8000
                  scheme: HTTP
                periodSeconds: 60
                failureThreshold: 100
              readinessProbe:
                httpGet:
                  path: /healthcheck
                  port: 8000
                  scheme: HTTP
                periodSeconds: 60
                failureThreshold: 100
            - name: jetstream-http
              image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.3
              imagePullPolicy: Always
              ports:
              - containerPort: 8000
        # The size variable defines the number of worker nodes to be created.
        # It must be equal to the number of hosts + 1 (for the leader node).
        size: 5
        workerTemplate:
          spec:
            nodeSelector:
              cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice
              cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 2x4
            containers:
            - name: worker
              args:
              - --server_port=38679
              - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677
              - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
              image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest
              imagePullPolicy: Always
              ports:
              - containerPort: 38679
              resources:
                limits:
                  google.com/tpu: "4"
    --- 
    apiVersion: v1
    kind: Service
    metadata:
      name: jetstream-svc
    spec:
      selector:
        app: jetstream-pathways
      ports:
      - protocol: TCP
        name: jetstream-http
        port: 8000
        targetPort: 8000
      
    Reemplaza lo siguiente:
    • TPU_ACCELERATOR_TYPE: Es el tipo de acelerador de TPU. Por ejemplo, tpu-v6e-slice
    • TPU_TOPOLOGY: Es la topología de la TPU. Por ejemplo, 2x4
    • GCS_CHECKPOINT_PATH: Es la ruta de acceso de GCS al punto de control.
  2. Aplica este archivo YAML. El servidor del modelo tardará un tiempo en restablecer el punto de control. En el caso del modelo de 70B, este proceso puede tardar alrededor de 2 minutos.
      kubectl apply -f pathways-job.yaml
          
  3. Consulta los registros de Kubernetes para ver si el servidor del modelo de JetStream está listo:
        kubectl logs -f jetstream-pathways-0 -c jax-tpu
        
    Verás un resultado similar al siguiente, que indica que el servidor del modelo de JetStream está listo para atender solicitudes:
        2025-03-02 02:15:07,682 - JetstreamLogger - INFO - Initializing the driver with 1 prefill engines and 1 generate engines in interleaved mode
        2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up prefill thread 0.
        2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up transfer thread 0.
        2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up generate thread 0.
        2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up detokenize thread 0.
        2025-03-02 02:15:07,685 - JetstreamLogger - INFO - Driver initialized.
        ...
        ...
        ...
        INFO:     Started server process [7]
        INFO:     Waiting for application startup.
        INFO:     Application startup complete.
        INFO:     Uvicorn running on http://0.0.0.0:9999 (Press CTRL+C to quit)
      

Conéctate al servidor del modelo

Puedes acceder a la implementación de JetStream Pathways a través del servicio ClusterIP de GKE. Solo se puede acceder al servicio ClusterIP desde el clúster. Por lo tanto, para acceder al servicio desde fuera del clúster, establece una sesión de redirección de puertos ejecutando el siguiente comando:

kubectl port-forward pod/${HEAD_POD} 8000:8000

Para comprobar que puedes acceder al servidor HTTP de JetStream, abre una terminal nueva y ejecuta el siguiente comando:

curl --request POST \
--header "Content-type: application/json" \
-s \
localhost:8000/generate \
--data \
'{
    "prompt": "What are the top 5 programming languages",
    "max_tokens": 200
}'

La solicitud inicial puede tardar varios segundos en completarse debido a la preparación del modelo. El resultado debería ser similar al siguiente ejemplo:

{
    "response": " used in software development?\nThe top 5 programming languages used in software development are:\n\n1. Java: Java is a popular programming language used for developing enterprise-level applications, Android apps, and web applications. Its platform independence and ability to run on any device that has a Java Virtual Machine (JVM) installed make it a favorite among developers.\n2. Python: Python is a versatile language that is widely used in software development, data analysis, artificial intelligence, and machine learning. Its simplicity, readability, and ease of use make it a popular choice among developers.\n3. JavaScript: JavaScript is a widely used programming language for web development, allowing developers to create interactive client-side functionality for web applications. It is also used for server-side programming, desktop and mobile application development, and game development.\n4. C++: C++ is a high-performance programming language used for developing operating systems, games, and other high-performance applications."
}

¿Qué sigue?