Fazer inferência multihost usando o Pathways

A inferência multihost é um método de execução de inferência de modelo que distribui o modelo em vários hosts de aceleradores. Isso permite a inferência de modelos grandes que não cabem em um único host. Os caminhos podem ser implantados para casos de uso de inferência multihost em lote e em tempo real.

Antes de começar

Você precisa ter:

Executar a inferência em lote usando o JetStream

O JetStream é um mecanismo com otimização de capacidade de processamento e memória para inferência de modelos de linguagem grandes (LLMs) em dispositivos XLA, principalmente Unidades de Processamento de Tensor (TPUs) escritas em JAX.

É possível usar uma imagem do Docker do JetStream pré-criada para executar uma carga de trabalho de inferência em lote, conforme mostrado no YAML a seguir. Ele foi criado com base no projeto OSS JetStream. Para mais informações sobre flags do MaxText-JetStream, consulte Flags do servidor JetStream MaxText. O exemplo a seguir usa chips Trillium (v6e-16) para carregar o ponto de verificação int8 do Llama3.1-405b e realizar inferência sobre ele. Este exemplo pressupõe que você já tem um cluster do GKE com pelo menos um pool de nós v6e-16.

Iniciar o servidor de modelo e o Pathways

  1. Receba as credenciais do cluster e adicione-as ao contexto local do kubectl.
          gcloud container clusters get-credentials $CLUSTER \
          --zone=$ZONE \
          --project=$PROJECT \
          && kubectl config set-context --current --namespace=default
        
  2. Implante a API LeaderWorkerSet (LWS).
          VERSION=v0.4.0
          kubectl apply --server-side -f "https://github.com/kubernetes-sigs/lws/releases/download/${VERSION}/manifests.yaml"
        
  3. Copie e cole o seguinte YAML em um arquivo chamado pathways-job.yaml: Esse YAML foi otimizado para o formato de corte v6e-16. Para mais informações sobre como converter um checkpoint do Meta em um checkpoint compatível com JAX, siga o guia de criação de checkpoints em Como criar checkpoints de inferência. Por exemplo, as instruções para o Llama3.1-405B estão disponíveis aqui: Conversão de checkpoint para Llama3.1-405B.
        apiVersion: leaderworkerset.x-k8s.io/v1
        kind: LeaderWorkerSet
        metadata:
          name: jetstream-pathways
          annotations:
            leaderworkerset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
        spec:
          replicas: 1
          leaderWorkerTemplate:
            leaderTemplate:
              metadata:
                labels: 
                  app: jetstream-pathways
              spec:
                nodeSelector:
                  cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice
                  cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 4x4
                tolerations:
                - key: "google.com/tpu"
                  operator: "Exists"
                  effect: "NoSchedule"
                containers:
                - name: pathways-proxy
                  image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest
                  args:
                  - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677
                  - --server_port=38681
                  - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
                  imagePullPolicy: Always
                  ports:
                  - containerPort: 38681
                - name: pathways-rm
                  env:       
                  - name: HOST_ADDRESS
                    value: "$(LWS_LEADER_ADDRESS)"
                  - name: TPU_SKIP_MDS_QUERY
                    value: "true"
                  image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest
                  args:
                  - --server_port=38677
                  - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
                  - --node_type=resource_manager
                  - --instance_count=1
                  - --instance_type=tpuv6e:TPU_TOPOLOGY # Example: 4x4
                  - --temporary_flags_for_debugging=temporary_flag_for_debugging_worker_expected_tpu_chip_config=megachip_tccontrol
                  imagePullPolicy: Always
                  ports:
                  - containerPort: 38677
                - name: jax-tpu
                  image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-pathways:v0.2.0
                  # Optimized settings used to serve Llama3.1-405b.
                  args:
                  - MaxText/configs/v5e/inference/llama3_405b_v5e-64.yml
                  - model_name=llama3.1-405b
                  - load_parameters_path=GCS_CHECKPOINT_PATH
                  - max_prefill_predict_length=1024
                  - max_target_length=2048
                  - async_checkpointing=false
                  - steps=1
                  - ici_fsdp_parallelism=1
                  - ici_autoregressive_parallelism=2
                  - ici_tensor_parallelism=8
                  - scan_layers=false
                  - weight_dtype=bfloat16
                  - per_device_batch_size=10
                  - enable_single_controller=true
                  - quantization=int8
                  - quantize_kvcache=true
                  - checkpoint_is_quantized=true
                  - enable_model_warmup=true
                  imagePullPolicy: Always
                  ports: 
                  - containerPort: 9000
                  startupProbe:
                    httpGet:
                      path: /healthcheck
                      port: 8000
                      scheme: HTTP
                    periodSeconds: 1
                    initialDelaySeconds: 900
                    failureThreshold: 10000
                  livenessProbe:
                    httpGet:
                      path: /healthcheck
                      port: 8000
                      scheme: HTTP
                    periodSeconds: 60
                    failureThreshold: 10
                  readinessProbe:
                    httpGet:
                      path: /healthcheck
                      port: 8000
                      scheme: HTTP
                    periodSeconds: 60
                    failureThreshold: 10
                - name: jetstream-http
                  image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.3
                  imagePullPolicy: Always
                  ports:
                  - containerPort: 8000
            # The size variable defines the number of worker nodes to be created.
            # It must be equal to the number of hosts + 1 (for the leader node).
            size: 5
            workerTemplate:
              spec:
                nodeSelector:
                  cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice
                  cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 4x4
                tolerations:
                - key: "google.com/tpu"
                  operator: "Exists"
                  effect: "NoSchedule"
                containers:
                - name: worker
                  args:
                  - --server_port=38679
                  - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677
                  - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
                  image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest
                  imagePullPolicy: Always
                  ports:
                  - containerPort: 38679
                  resources:
                    limits:
                      google.com/tpu: "4"
        --- 
        apiVersion: v1
        kind: Service
        metadata:
          name: jetstream-svc
        spec:
          selector:
            app: jetstream-pathways
          ports:
          - protocol: TCP
            name: jetstream-http
            port: 8000
            targetPort: 8000
        
    Substitua:
    • TPU_ACCELERATOR_TYPE: o tipo de acelerador de TPU. Por exemplo, tpu-v6e-slice.
    • TPU_TOPOLOGY: a topologia da TPU. Por exemplo, 2x4.
    • GCS_CHECKPOINT_PATH: o caminho do GCS para o ponto de verificação.
    Aplique este YAML. Aguarde até que o PathwaysJob seja programado. Depois de programado, o servidor do modelo pode levar algum tempo para restaurar o ponto de verificação. Para o modelo 405B, isso leva cerca de 7 minutos.
  4. Confira os registros do Kubernetes para saber se o servidor de modelo do JetStream está pronto: A carga de trabalho foi chamada de "jetstream-pathways" no YAML anterior, e "0" é o nó principal.
          kubectl logs -f jetstream-pathways-0 -c jax-tpu
          
    A saída é semelhante à seguinte, que indica que o servidor do modelo JetStream está pronto para atender às solicitações:
        2025-03-02 02:15:07,682 - JetstreamLogger - INFO - Initializing the driver with 1 prefill engines and 1 generate engines in interleaved mode
        2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up prefill thread 0.
        2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up transfer thread 0.
        2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up generate thread 0.
        2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up detokenize thread 0.
        2025-03-02 02:15:07,685 - JetstreamLogger - INFO - Driver initialized.
        ...
        ...
        ...
        INFO:     Started server process [7]
        INFO:     Waiting for application startup.
        INFO:     Application startup complete.
        INFO:     Uvicorn running on http://0.0.0.0:9999 (Press CTRL+C to quit)
        

Conectar-se ao servidor de modelos

Acesse a implantação do JetStream Pathways usando o serviço ClusterIP do GKE. O serviço ClusterIP só pode ser acessado dentro do cluster. Portanto, para acessar o serviço de fora do cluster, primeiro estabeleça uma sessão de encaminhamento de portas executando o seguinte comando:

kubectl port-forward pod/${HEAD_POD} 8000:8000

Para verificar se é possível acessar o servidor HTTP JetStream, abra um novo terminal e execute o seguinte comando:

curl --request POST \
--header "Content-type: application/json" \
-s \
localhost:8000/generate \
--data \
'{
    "prompt": "What are the top 5 programming languages",
    "max_tokens": 200
}'

A solicitação inicial pode levar alguns segundos para ser concluída devido ao aquecimento do modelo. A saída será semelhante a esta:

{
    "response": " for web development?\nThe top 5 programming languages for web development are:\n1. **JavaScript**: JavaScript is the most popular language for web development, used by over 90% of websites for client-side scripting. It's also popular for server-side programming with technologies like Node.js.\n2. **HTML/CSS**: HTML (Hypertext Markup Language) and CSS (Cascading Style Sheets) are not programming languages, but are essential for building websites. HTML is used for structuring content, while CSS is used for styling and layout.\n3. **Python**: Python is a popular language for web development, especially with frameworks like Django and Flask. It's known for its simplicity, flexibility, and large community of developers.\n4. **Java**: Java is a popular language for building enterprise-level web applications, especially with frameworks like Spring and Hibernate. It's known for its platform independence, strong security features, and large community of developers.\n5. **PHP**: PHP is a mature language for web"
}

Inferência desagregada

A veiculação desagregada é uma técnica para executar modelos de linguagem grandes (LLMs) que separa as etapas de pré-preenchimento e decodificação em processos diferentes, possivelmente em máquinas diferentes. Isso permite uma melhor utilização dos recursos e pode levar a melhorias no desempenho e na eficiência, especialmente para modelos grandes.

  • Pré-preenchimento: esta etapa processa o comando de entrada e gera uma representação intermediária (como um cache de chave-valor). Ela geralmente exige muito poder de computação.
  • Decodificação: esta etapa gera os tokens de saída, um por um, usando a representação de pré-preenchimento. Normalmente, ele é limitado pela largura de banda da memória.

Ao separar essas etapas, a veiculação desagregada permite que o pré-preenchimento e a decodificação sejam executados em paralelo, melhorando a taxa de transferência e a latência.

Para ativar a veiculação desagregada, modifique o seguinte YAML para usar duas divisões v6e-8: uma para pré-preenchimento e outra para geração. Antes de continuar, verifique se o cluster do GKE tem pelo menos dois pools de nós configurados com essa topologia v6e-8. Para um desempenho ideal, flags específicas do XLA foram configuradas.

Crie um checkpoint do llama2-70b seguindo o mesmo processo de criação do checkpoint do llama3.1-405b, detalhado na seção anterior.

  1. Para iniciar o servidor JetStream no modo desagregado usando o Pathways, copie e cole o seguinte YAML em um arquivo chamado pathways-job.yaml:
    apiVersion: leaderworkerset.x-k8s.io/v1
    kind: LeaderWorkerSet
    metadata:
      name: jetstream-pathways
      annotations:
        leaderworkerset.sigs.k8s.io/subgroup-exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      replicas: 1
      leaderWorkerTemplate:
        subGroupPolicy:
          subGroupSize: 2
        leaderTemplate:
          metadata:
            labels:
              app: jetstream-pathways
          spec:
            nodeSelector:
              cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice
              cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 2x4
            tolerations:
            - key: "google.com/tpu"
              operator: "Exists"
              effect: "NoSchedule"
            containers:
            - name: pathways-proxy
              image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest
              args:
              - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677
              - --server_port=38681
              - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
              imagePullPolicy: Always
              ports:
              - containerPort: 38681
            - name: pathways-rm
              env:       
              - name: HOST_ADDRESS
                value: "$(LWS_LEADER_ADDRESS)"
              - name: TPU_SKIP_MDS_QUERY
                value: "true"
              image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest
              args:
              - --server_port=38677
              - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
              - --node_type=resource_manager
              - --instance_count=2
              - --instance_type=tpuv6e:TPU_TOPOLOGY # Example: 2x4
              imagePullPolicy: Always
              ports:
              - containerPort: 38677
            - name: jax-tpu
              image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-pathways:v0.2.0
              # Optimized settings used to serve Llama2-70b.
              args:
              - MaxText/configs/base.yml
              - tokenizer_path=assets/tokenizer.llama2
              - load_parameters_path=GCS_CHECKPOINT_PATH
              - max_prefill_predict_length=1024
              - max_target_length=2048
              - model_name=llama2-70b
              - ici_fsdp_parallelism=1
              - ici_autoregressive_parallelism=1
              - ici_tensor_parallelism=-1
              - scan_layers=false
              - weight_dtype=bfloat16
              - per_device_batch_size=1
              - checkpoint_is_quantized=true 
              - quantization=int8
              - quantize_kvcache=true
              - compute_axis_order=0,2,1,3
              - ar_cache_axis_order=0,2,1,3
              - stack_prefill_result_cache=True
              # Specify disaggregated mode to run Jetstream
              - inference_server=ExperimentalMaxtextDisaggregatedServer_8
              - inference_benchmark_test=True
              - enable_model_warmup=True
              env:
              - name: LOG_LEVEL
                value: "INFO"
              imagePullPolicy: Always
              securityContext:
                capabilities:
                  add: ["SYS_PTRACE", "NET_ADMIN", "SYS_TIME"]
              ports: 
              - containerPort: 9000
              startupProbe:
                httpGet:
                  path: /healthcheck
                  port: 8000
                  scheme: HTTP
                periodSeconds: 1
                initialDelaySeconds: 240
                failureThreshold: 10000
              livenessProbe:
                httpGet:
                  path: /healthcheck
                  port: 8000
                  scheme: HTTP
                periodSeconds: 60
                failureThreshold: 100
              readinessProbe:
                httpGet:
                  path: /healthcheck
                  port: 8000
                  scheme: HTTP
                periodSeconds: 60
                failureThreshold: 100
            - name: jetstream-http
              image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.3
              imagePullPolicy: Always
              ports:
              - containerPort: 8000
        # The size variable defines the number of worker nodes to be created.
        # It must be equal to the number of hosts + 1 (for the leader node).
        size: 5
        workerTemplate:
          spec:
            nodeSelector:
              cloud.google.com/gke-tpu-accelerator: TPU_ACCELERATOR_TYPE # Example: tpu-v6e-slice
              cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY # Example: 2x4
            containers:
            - name: worker
              args:
              - --server_port=38679
              - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677
              - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
              image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest
              imagePullPolicy: Always
              ports:
              - containerPort: 38679
              resources:
                limits:
                  google.com/tpu: "4"
    --- 
    apiVersion: v1
    kind: Service
    metadata:
      name: jetstream-svc
    spec:
      selector:
        app: jetstream-pathways
      ports:
      - protocol: TCP
        name: jetstream-http
        port: 8000
        targetPort: 8000
      
    Substitua:
    • TPU_ACCELERATOR_TYPE: o tipo de acelerador de TPU. Por exemplo, tpu-v6e-slice.
    • TPU_TOPOLOGY: a topologia da TPU. Por exemplo, 2x4.
    • GCS_CHECKPOINT_PATH: o caminho do GCS para o ponto de verificação.
  2. Aplique esse YAML. O servidor de modelo vai levar algum tempo para restaurar o checkpoint. Para o modelo de 70 bilhões de parâmetros, isso pode levar cerca de 2 minutos.
      kubectl apply -f pathways-job.yaml
          
  3. Confira os registros do Kubernetes para saber se o servidor de modelo do JetStream está pronto:
        kubectl logs -f jetstream-pathways-0 -c jax-tpu
        
    Você vai ver uma saída semelhante à seguinte, que indica que o servidor do modelo JetStream está pronto para atender às solicitações:
        2025-03-02 02:15:07,682 - JetstreamLogger - INFO - Initializing the driver with 1 prefill engines and 1 generate engines in interleaved mode
        2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up prefill thread 0.
        2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up transfer thread 0.
        2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up generate thread 0.
        2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up detokenize thread 0.
        2025-03-02 02:15:07,685 - JetstreamLogger - INFO - Driver initialized.
        ...
        ...
        ...
        INFO:     Started server process [7]
        INFO:     Waiting for application startup.
        INFO:     Application startup complete.
        INFO:     Uvicorn running on http://0.0.0.0:9999 (Press CTRL+C to quit)
      

Conectar-se ao servidor de modelos

É possível acessar a implantação do JetStream Pathways pelo serviço ClusterIP do GKE. O serviço ClusterIP só pode ser acessado de dentro do cluster. Portanto, para acessar o serviço de fora do cluster, estabeleça uma sessão de encaminhamento de portas executando o seguinte comando:

kubectl port-forward pod/${HEAD_POD} 8000:8000

Para verificar se é possível acessar o servidor HTTP JetStream, abra um novo terminal e execute o seguinte comando:

curl --request POST \
--header "Content-type: application/json" \
-s \
localhost:8000/generate \
--data \
'{
    "prompt": "What are the top 5 programming languages",
    "max_tokens": 200
}'

A solicitação inicial pode levar alguns segundos para ser concluída devido ao aquecimento do modelo. A saída será semelhante a esta:

{
    "response": " used in software development?\nThe top 5 programming languages used in software development are:\n\n1. Java: Java is a popular programming language used for developing enterprise-level applications, Android apps, and web applications. Its platform independence and ability to run on any device that has a Java Virtual Machine (JVM) installed make it a favorite among developers.\n2. Python: Python is a versatile language that is widely used in software development, data analysis, artificial intelligence, and machine learning. Its simplicity, readability, and ease of use make it a popular choice among developers.\n3. JavaScript: JavaScript is a widely used programming language for web development, allowing developers to create interactive client-side functionality for web applications. It is also used for server-side programming, desktop and mobile application development, and game development.\n4. C++: C++ is a high-performance programming language used for developing operating systems, games, and other high-performance applications."
}

A seguir