Servir un LLM mediante TPUs de varios hosts en GKE con Saxml

En este tutorial se muestra cómo desplegar y servir un modelo de lenguaje grande (LLM) mediante un grupo de nodos de segmento de TPU de varios hosts en Google Kubernetes Engine (GKE) con Saxml para crear una arquitectura escalable y eficiente.

Fondo

Saxml es un sistema experimental que sirve los frameworks Paxml, JAX y PyTorch. Puedes usar las TPUs para acelerar el procesamiento de datos con estos frameworks. Para mostrar la implementación de TPUs en GKE, en este tutorial se usa el modelo de prueba LmCloudSpmd175B32Test de 175 000 millones. GKE despliega este modelo de prueba en dos grupos de nodos de segmentos de TPU v5e con topologías 4x8 respectivamente.

Para desplegar correctamente el modelo de prueba, la topología de la TPU se ha definido en función del tamaño del modelo. Dado que el modelo de N mil millones de 16 bits requiere aproximadamente 2 veces (2 × N) GB de memoria, el modelo 175B LmCloudSpmd175B32Test requiere unos 350 GB de memoria. El chip de TPU único de la versión 5e de TPU tiene 16 GB. Para admitir 350 GB, GKE necesita 21 chips de TPU v5e (350/16= 21). Según la asignación de la configuración de TPU, la configuración de TPU adecuada para este tutorial es la siguiente:

  • Tipo de máquina: ct5lp-hightpu-4t
  • Topología: 4x8 (32 chips de TPU)

Seleccionar la topología de TPU adecuada para servir un modelo es importante al desplegar TPUs en GKE. Para obtener más información, consulta Planificar la configuración de las TPU.

Preparar el entorno

  1. En la Google Cloud consola, inicia una instancia de Cloud Shell:
    Abrir Cloud Shell

  2. Define las variables de entorno predeterminadas:

      gcloud config set project PROJECT_ID
      export PROJECT_ID=$(gcloud config get project)
      export CONTROL_PLANE_LOCATION=CONTROL_PLANE_LOCATION
      export BUCKET_NAME=PROJECT_ID-gke-bucket
    

    Sustituye los siguientes valores:

    • PROJECT_ID: tu Google Cloud ID de proyecto.
    • CONTROL_PLANE_LOCATION: la zona de Compute Engine del plano de control de tu clúster. Selecciona la zona en la que está disponible ct5lp-hightpu-4t.

    En este comando, BUCKET_NAME especifica el nombre del Google Cloud segmento de almacenamiento en el que se almacenarán las configuraciones del servidor de administrador de Saxml.

Crear un clúster estándar de GKE

Usa Cloud Shell para hacer lo siguiente:

  1. Crea un clúster estándar que use Workload Identity Federation para GKE:

    gcloud container clusters create saxml \
        --location=${CONTROL_PLANE_LOCATION} \
        --workload-pool=${PROJECT_ID}.svc.id.goog \
        --cluster-version=VERSION \
        --num-nodes=4
    

    Sustituye VERSION por el número de versión de GKE. GKE admite TPU v5e en la versión 1.27.2-gke.2100 y posteriores. Para obtener más información, consulta Disponibilidad de las TPU en GKE.

    La creación del clúster puede tardar varios minutos.

  2. Crea el primer grupo de nodos llamado tpu1:

    gcloud container node-pools create tpu1 \
        --location=${CONTROL_PLANE_LOCATION} \
        --machine-type=ct5lp-hightpu-4t \
        --tpu-topology=4x8 \
        --num-nodes=8 \
        --cluster=saxml
    

    El valor de la marca --num-nodes se calcula dividiendo la topología de la TPU entre el número de chips de TPU por segmento de TPU. En este caso, sería (4 * 8) / 4.

  3. Crea el segundo grupo de nodos llamado tpu2:

    gcloud container node-pools create tpu2 \
        --location=${CONTROL_PLANE_LOCATION} \
        --machine-type=ct5lp-hightpu-4t \
        --tpu-topology=4x8 \
        --num-nodes=8 \
        --cluster=saxml
    

    El valor de la marca --num-nodes se calcula dividiendo la topología de la TPU entre el número de chips de TPU por segmento de TPU. En este caso, sería (4 * 8) / 4.

Has creado los siguientes recursos:

  • Un clúster estándar con cuatro nodos de CPU.
  • Dos grupos de nodos de segmentos de TPU v5e con topología 4x8. Cada grupo de nodos representa ocho nodos de segmento de TPU con cuatro chips de TPU cada uno.

El modelo de 175.000 millones de parámetros se debe servir en un slice de TPU v5e de varios hosts con un 4x8slice de topología (32 chips de TPU v5e) como mínimo.

Crea un segmento de Cloud Storage

Crea un segmento de Cloud Storage para almacenar las configuraciones del servidor de administrador de Saxml. Un servidor de administrador en ejecución guarda periódicamente su estado y los detalles de los modelos publicados.

En Cloud Shell, ejecuta lo siguiente:

gcloud storage buckets create gs://${BUCKET_NAME}

Configurar el acceso de las cargas de trabajo mediante Workload Identity Federation para GKE

Asigna una cuenta de servicio de Kubernetes a la aplicación y configura esa cuenta de servicio de Kubernetes para que actúe como cuenta de servicio de gestión de identidades y accesos.

  1. Configura kubectl para que se comunique con tu clúster:

    gcloud container clusters get-credentials saxml --location=${CONTROL_PLANE_LOCATION}
    
  2. Crea una cuenta de servicio de Kubernetes para que la use tu aplicación:

    kubectl create serviceaccount sax-sa --namespace default
    
  3. Crea una cuenta de servicio de IAM para tu aplicación:

    gcloud iam service-accounts create sax-iam-sa
    
  4. Añade un enlace de política de gestión de identidades y accesos (IAM) a tu cuenta de servicio de IAM para leer y escribir en Cloud Storage:

    gcloud projects add-iam-policy-binding ${PROJECT_ID} \
      --member "serviceAccount:sax-iam-sa@${PROJECT_ID}.iam.gserviceaccount.com" \
      --role roles/storage.admin
    
  5. Permite que la cuenta de servicio de Kubernetes suplante la identidad de la cuenta de servicio de gestión de identidades y accesos añadiendo un enlace de política de gestión de identidades y accesos entre las dos cuentas de servicio. Esta vinculación permite que la cuenta de servicio de Kubernetes actúe como la cuenta de servicio de IAM, de modo que la cuenta de servicio de Kubernetes pueda leer y escribir en Cloud Storage.

    gcloud iam service-accounts add-iam-policy-binding sax-iam-sa@${PROJECT_ID}.iam.gserviceaccount.com \
      --role roles/iam.workloadIdentityUser \
      --member "serviceAccount:${PROJECT_ID}.svc.id.goog[default/sax-sa]"
    
  6. Anota la cuenta de servicio de Kubernetes con la dirección de correo de la cuenta de servicio de gestión de identidades y accesos. De esta forma, tu aplicación de muestra sabrá qué cuenta de servicio debe usar para acceder a los servicios de Google Cloud . Por lo tanto, cuando la aplicación usa bibliotecas de cliente de las APIs de Google estándar para acceder a los servicios, usa esa cuenta de servicio de IAM. Google Cloud

    kubectl annotate serviceaccount sax-sa \
      iam.gke.io/gcp-service-account=sax-iam-sa@${PROJECT_ID}.iam.gserviceaccount.com
    

Desplegar Saxml

En esta sección, implementa el servidor de administrador de Saxml y el servidor de modelos de Saxml.

Desplegar el servidor de administrador de Saxml

  1. Crea el siguiente archivo de manifiesto sax-admin-server.yaml:

    apiVersion: apps/v1
    kind: Deployment
    metadata:
      name: sax-admin-server
    spec:
      replicas: 1
      selector:
        matchLabels:
          app: sax-admin-server
      template:
        metadata:
          labels:
            app: sax-admin-server
        spec:
          hostNetwork: false
          serviceAccountName: sax-sa
          containers:
          - name: sax-admin-server
            image: us-docker.pkg.dev/cloud-tpu-images/inference/sax-admin-server:v1.1.0
            securityContext:
              privileged: true
            ports:
            - containerPort: 10000
            env:
            - name: GSBUCKET
              value: BUCKET_NAME
  2. Sustituye BUCKET_NAME por el Cloud Storage que has creado anteriormente:

    perl -pi -e 's|BUCKET_NAME|BUCKET_NAME|g' sax-admin-server.yaml
    
  3. Aplica el archivo de manifiesto:

    kubectl apply -f sax-admin-server.yaml
    
  4. Comprueba que el pod del servidor de administrador esté en funcionamiento:

    kubectl get deployment
    

    El resultado debería ser similar al siguiente:

    NAME               READY   UP-TO-DATE   AVAILABLE   AGE
    sax-admin-server   1/1     1            1           52s
    

Desplegar el servidor de modelos de Saxml

Las cargas de trabajo que se ejecutan en slices de TPU de varios hosts requieren un identificador de red estable para que cada Pod pueda detectar los peers del mismo slice de TPU. Para definir estos identificadores, usa IndexedJob, StatefulSet con un servicio sin interfaz gráfica o JobSet, que crea automáticamente un servicio sin interfaz gráfica para todos los trabajos que pertenecen a JobSet. Un JobSet es una API de carga de trabajo que te permite gestionar un grupo de trabajos de Kubernetes como una unidad. El caso de uso más habitual de un JobSet es el entrenamiento distribuido, pero también puedes usarlo para ejecutar cargas de trabajo por lotes.

En la siguiente sección se muestra cómo gestionar varios grupos de pods de servidor de modelos con JobSet.

  1. Instala JobSet v0.2.3 o una versión posterior.

    kubectl apply --server-side -f https://github.com/kubernetes-sigs/jobset/releases/download/JOBSET_VERSION/manifests.yaml
    

    Sustituye JOBSET_VERSION por la versión de JobSet. Por ejemplo, v0.2.3.

  2. Valida que el controlador JobSet se esté ejecutando en el espacio de nombres jobset-system:

    kubectl get pod -n jobset-system
    

    El resultado debería ser similar al siguiente:

    NAME                                        READY   STATUS    RESTARTS   AGE
    jobset-controller-manager-69449d86bc-hp5r6   2/2     Running   0          2m15s
    
  3. Despliega dos servidores de modelos en dos grupos de nodos de segmentos de TPU. Guarda el siguiente archivo de manifiesto sax-model-server-set:

    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: sax-model-server-set
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: sax-model-server
          replicas: 2
          template:
            spec:
              parallelism: 8
              completions: 8
              backoffLimit: 0
              template:
                spec:
                  serviceAccountName: sax-sa
                  hostNetwork: true
                  dnsPolicy: ClusterFirstWithHostNet
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                    cloud.google.com/gke-tpu-topology: 4x8
                  containers:
                  - name: sax-model-server
                    image: us-docker.pkg.dev/cloud-tpu-images/inference/sax-model-server:v1.1.0
                    args: ["--port=10001","--sax_cell=/sax/test", "--platform_chip=tpuv5e"]
                    ports:
                    - containerPort: 10001
                    - containerPort: 8471
                    securityContext:
                      privileged: true
                    env:
                    - name: SAX_ROOT
                      value: "gs://BUCKET_NAME/sax-root"
                    - name: MEGASCALE_NUM_SLICES
                      value: ""
                    resources:
                      requests:
                        google.com/tpu: 4
                      limits:
                        google.com/tpu: 4
  4. Sustituye BUCKET_NAME por el Cloud Storage que has creado anteriormente:

    perl -pi -e 's|BUCKET_NAME|BUCKET_NAME|g' sax-model-server-set.yaml
    

    En este manifiesto:

    • replicas: 2 es el número de réplicas de Job. Cada trabajo representa un servidor de modelos. Por lo tanto, un grupo de 8 pods.
    • parallelism: 8 y completions: 8 son iguales al número de nodos de cada grupo de nodos.
    • backoffLimit: 0 debe ser cero para marcar el trabajo como fallido si falla algún pod.
    • ports.containerPort: 8471 es el puerto predeterminado para la comunicación de las VMs
    • name: MEGASCALE_NUM_SLICES anula la variable de entorno porque GKE no está ejecutando el entrenamiento de Multislice.
  5. Aplica el archivo de manifiesto:

    kubectl apply -f sax-model-server-set.yaml
    
  6. Verifica el estado de los pods del servidor de administración y del servidor de modelos de Saxml:

    kubectl get pods
    

    El resultado debería ser similar al siguiente:

    NAME                                              READY   STATUS    RESTARTS   AGE
    sax-admin-server-557c85f488-lnd5d                 1/1     Running   0          35h
    sax-model-server-set-sax-model-server-0-0-nj4sm   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-1-sl8w4   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-2-hb4rk   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-3-qv67g   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-4-pzqz6   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-5-nm7mz   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-6-7br2x   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-7-4pw6z   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-0-8mlf5   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-1-h6z6w   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-2-jggtv   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-3-9v8kj   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-4-6vlb2   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-5-h689p   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-6-bgv5k   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-7-cd6gv   1/1     Running   0          24m
    

En este ejemplo, hay 16 contenedores de servidor de modelos: sax-model-server-set-sax-model-server-0-0-nj4sm y sax-model-server-set-sax-model-server-1-0-8mlf5 son los dos servidores de modelos principales de cada grupo.

Tu clúster de Saxml tiene dos servidores de modelos implementados en dos grupos de nodos de slices de TPU v5e con topologías 4x8, respectivamente.

Implementar el servidor HTTP y el balanceador de carga de Saxml

  1. Usa la siguiente imagen de servidor HTTP prediseñada. Guarda el siguiente archivo de manifiesto sax-http.yaml:

    apiVersion: apps/v1
    kind: Deployment
    metadata:
      name: sax-http
    spec:
      replicas: 1
      selector:
        matchLabels:
          app: sax-http
      template:
        metadata:
          labels:
            app: sax-http
        spec:
          hostNetwork: false
          serviceAccountName: sax-sa
          containers:
          - name: sax-http
            image: us-docker.pkg.dev/cloud-tpu-images/inference/sax-http:v1.0.0
            ports:
            - containerPort: 8888
            env:
            - name: SAX_ROOT
              value: "gs://BUCKET_NAME/sax-root"
    ---
    apiVersion: v1
    kind: Service
    metadata:
      name: sax-http-lb
    spec:
      selector:
        app: sax-http
      ports:
      - protocol: TCP
        port: 8888
        targetPort: 8888
      type: LoadBalancer
  2. Sustituye BUCKET_NAME por el Cloud Storage que has creado anteriormente:

    perl -pi -e 's|BUCKET_NAME|BUCKET_NAME|g' sax-http.yaml
    
  3. Aplica el manifiesto sax-http.yaml:

    kubectl apply -f sax-http.yaml
    
  4. Espera a que termine de crearse el contenedor del servidor HTTP:

    kubectl get pods
    

    El resultado debería ser similar al siguiente:

    NAME                                              READY   STATUS    RESTARTS   AGE
    sax-admin-server-557c85f488-lnd5d                 1/1     Running   0          35h
    sax-http-65d478d987-6q7zd                         1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-0-nj4sm   1/1     Running   0          24m
    ...
    
  5. Espera a que se asigne una dirección IP externa al servicio:

    kubectl get svc
    

    El resultado debería ser similar al siguiente:

    NAME           TYPE           CLUSTER-IP    EXTERNAL-IP   PORT(S)          AGE
    sax-http-lb    LoadBalancer   10.48.11.80   10.182.0.87   8888:32674/TCP   7m36s
    

Usar Saxml

Carga, implementa y ofrece el modelo en Saxml en el slice multihost de TPU v5e:

Cargar el modelo

  1. Obtén la dirección IP del balanceador de carga de Saxml.

    LB_IP=$(kubectl get svc sax-http-lb -o jsonpath='{.status.loadBalancer.ingress[*].ip}')
    PORT="8888"
    
  2. Carga el modelo de prueba LmCloudSpmd175B en dos grupos de nodos de segmento de TPU v5e:

    curl --request POST \
    --header "Content-type: application/json" \
    -s ${LB_IP}:${PORT}/publish --data \
    '{
        "model": "/sax/test/spmd",
        "model_path": "saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd175B32Test",
        "checkpoint": "None",
        "replicas": 2
    }'
    

    El modelo de prueba no tiene un punto de control optimizado, sino que los pesos se generan de forma aleatoria. La carga del modelo puede tardar hasta 10 minutos.

    El resultado debería ser similar al siguiente:

    {
        "model": "/sax/test/spmd",
        "path": "saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd175B32Test",
        "checkpoint": "None",
        "replicas": 2
    }
    
  3. Comprueba la preparación del modelo:

    kubectl logs sax-model-server-set-sax-model-server-0-0-nj4sm
    

    El resultado debería ser similar al siguiente:

    ...
    loading completed.
    Successfully loaded model for key: /sax/test/spmd
    

    El modelo está totalmente cargado.

  4. Obtener información sobre el modelo:

    curl --request GET \
    --header "Content-type: application/json" \
    -s ${LB_IP}:${PORT}/listcell --data \
    '{
        "model": "/sax/test/spmd"
    }'
    

    El resultado debería ser similar al siguiente:

    {
    "model": "/sax/test/spmd",
    "model_path": "saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd175B32Test",
    "checkpoint": "None",
    "max_replicas": 2,
    "active_replicas": 2
    }
    

Aplicar el modelo

Enviar una solicitud de petición:

curl --request POST \
--header "Content-type: application/json" \
-s ${LB_IP}:${PORT}/generate --data \
'{
  "model": "/sax/test/spmd",
  "query": "How many days are in a week?"
}'

El resultado muestra un ejemplo de la respuesta del modelo. Es posible que esta respuesta no sea significativa porque el modelo de prueba tiene pesos aleatorios.

Anular la publicación del modelo

Ejecuta el siguiente comando para anular la publicación del modelo:

curl --request POST \
--header "Content-type: application/json" \
-s ${LB_IP}:${PORT}/unpublish --data \
'{
    "model": "/sax/test/spmd"
}'

El resultado debería ser similar al siguiente:

{
  "model": "/sax/test/spmd"
}