Servir LLMs mediante TPUs de varios hosts en GKE con JetStream y Pathways

En esta guía se explica cómo servir modelos de lenguaje extensos (LLMs) de vanguardia, como Llama 3.1 405B, en Google Kubernetes Engine (GKE) mediante unidades de procesamiento tensorial (TPUs) en varios nodos.

En esta guía se muestra cómo usar tecnologías portátiles de código abierto (Kubernetes, JetStream, Pathways on Cloud y la API LeaderWorkerSet [LWS]) para desplegar y servir cargas de trabajo de IA y aprendizaje automático en GKE. Para ello, se aprovechan el control granular, la escalabilidad, la resiliencia, la portabilidad y la rentabilidad de GKE.

Fondo

Los modelos de lenguaje extenso han aumentado de tamaño y ya no caben en una sola porción de TPU de host. Para la inferencia de aprendizaje automático, puedes usar Pathways en Cloud para ejecutar inferencias de varios hosts a gran escala en GKE en varios nodos de TPU interconectados. En esta guía, se explica cómo aprovisionar un clúster de GKE con las slices de TPU de varios hosts, usar los archivos binarios de Pathways on Cloud, iniciar el servidor JetStream con el framework MaxText y hacer solicitudes de inferencia de varios hosts.

Si sirves un LLM mediante TPUs en GKE con JetStream, MaxText y Pathways, puedes crear una solución de servicio estable y lista para producción con todas las ventajas de Kubernetes gestionado, como la rentabilidad, la escalabilidad y la mayor disponibilidad. En esta sección se describen las tecnologías clave que se usan en este tutorial.

Acerca de las TPUs

Las TPUs son circuitos integrados para aplicaciones específicas (ASIC) desarrollados a medida por Google que se utilizan para acelerar modelos de aprendizaje automático e IA creados con frameworks como TensorFlow, PyTorch y JAX.

Antes de usar las TPUs en GKE, te recomendamos que completes el siguiente plan de formación:

  1. Consulta la arquitectura del sistema de las TPU de Cloud para obtener información sobre la disponibilidad de las versiones actuales de las TPU.
  2. Consulta información sobre las TPUs en GKE.

En este tutorial se explica cómo servir el modelo Llama 3.1-405B. GKE implementa el modelo en nodos de TPU v6e de varios hosts con topologías de TPU configuradas en función de los requisitos del modelo para servir peticiones con baja latencia.

Rutas de aprendizaje en Cloud

Pathways es una capa de orquestación a gran escala para aceleradores. Pathways se ha diseñado específicamente para permitir la exploración de nuevos sistemas e ideas de investigación de aprendizaje automático, al tiempo que mantiene el rendimiento de vanguardia de los modelos actuales. Pathways permite que un solo proceso de cliente de JAX coordine la computación en una o varias grandes porciones de TPU, lo que agiliza las computaciones de aprendizaje automático que abarcan cientos o miles de chips de TPU.

JetStream

JetStream es un framework de servicio de inferencia de código abierto desarrollado por Google. JetStream permite realizar inferencias de alto rendimiento, alto volumen de procesamiento y memoria optimizada en TPUs y GPUs. JetStream ofrece optimizaciones de rendimiento avanzadas, como la creación de minilotes continua, la optimización de la caché de valores de clave y las técnicas de cuantización, para facilitar la implementación de LLMs. JetStream permite que PyTorch/XLA y JAX TPU sirvan para optimizar el rendimiento.

MaxText

MaxText es una implementación de LLM de JAX de alto rendimiento, escalable y adaptable, creada a partir de bibliotecas de JAX de código abierto, como Flax, Orbax y Optax. La implementación de LLM solo con decodificador de MaxText está escrita en Python. Aprovecha el compilador XLA para conseguir un alto rendimiento sin necesidad de crear kernels personalizados.

Para obtener más información sobre los modelos y tamaños de parámetros más recientes que admite MaxText, consulta el repositorio del proyecto MaxText.

Llama 3.1 405B

Llama 3.1 405B es un modelo de lenguaje extenso de Meta diseñado para llevar a cabo diversas tareas de procesamiento del lenguaje natural, como la generación de texto, la traducción y la respuesta a preguntas. GKE ofrece la infraestructura necesaria para satisfacer las necesidades de entrenamiento y servicio distribuidos de modelos de esta escala.

Para obtener más información, consulta la documentación de Llama.

Arquitectura

En esta sección se describe la arquitectura de GKE que se usa en este tutorial. La arquitectura incluye un clúster de GKE Standard que aprovisiona TPUs y aloja componentes de JetStream y Pathways para desplegar y servir el modelo.

En el siguiente diagrama se muestran los componentes de esta arquitectura:

Arquitectura de un clúster de GKE con un grupo de nodos de TPU de varios hosts que contiene los componentes JetStream y Pathways.

Esta arquitectura incluye los siguientes componentes:

  • Un clúster regional de GKE Standard.
  • Un grupo de nodos de segmento de TPU multihost que aloja el despliegue de JetStream y los componentes de Pathways.
  • Pathways resource manager gestiona los recursos del acelerador y coordina la asignación de aceleradores a los trabajos de los usuarios.
  • El Pathways client se coordina con el Pathways resource manager para determinar dónde se colocan los programas compilados para su ejecución.
  • El Pathways worker se ejecuta y realiza cálculos en máquinas aceleradoras, y envía datos a tu carga de trabajo a través del servidor proxy IFRT.
  • IFRT proxy client implementa la API Interim Framework Runtime (IFRT) de OSS y actúa como puente de comunicación entre tu carga de trabajo y los componentes de Pathways.
  • El IFRT proxy server recibe solicitudes del IFRT proxy client y las reenvía al Pathways client, distribuyendo el trabajo.
  • El contenedor JetStream-Pathways proporciona un servidor de inferencia basado en JAX que recibe solicitudes de inferencia y delega sus procesos de ejecución en Pathways workers.
  • El componente Service distribuye el tráfico entrante a todas las réplicas de JetStream HTTP.
  • JetStream HTTP es un servidor HTTP que acepta solicitudes como envoltorio del formato requerido de JetStream y las envía al cliente GRPC de JetStream.