Entrega un LLM con TPU en GKE con JetStream y PyTorch

En esta guía, se muestra cómo entregar un modelo de lenguaje grande (LLM) que usa unidades de procesamiento tensorial (TPU) en Google Kubernetes Engine (GKE) con JetStream a través de PyTorch. En esta guía, descargarás ponderaciones de modelos en Cloud Storage y, luego, las implementarás en un clúster de Autopilot o Standard de GKE con un contenedor que ejecute JetStream.

Si necesitas la escalabilidad, la resiliencia y la rentabilidad que ofrecen las funciones de Kubernetes cuando implementas tu modelo en JetStream, esta guía es un buen punto de partida.

Esta guía está dirigida a clientes de IA generativa que usan PyTorch, usuarios nuevos o existentes de GKE, ingenieros de AA, ingenieros de MLOps (DevOps) o administradores de plataformas interesados en usar las funciones de organización de contenedores de Kubernetes para entrega de LLM.

Antecedentes

Con la entrega de un LLM con TPU en GKE con JetStream, puedes compilar una solución de entrega sólida y lista para la producción con todos los beneficios de Kubernetes administrado, incluida la rentabilidad, escalabilidad y disponibilidad mayor. En esta sección, se describen las tecnologías clave que se usan en este instructivo.

Acerca de las TPU

Las TPU son circuitos integrados personalizados específicos de aplicaciones (ASIC) de Google que se usan para acelerar el aprendizaje automático y los modelos de IA compilados con frameworks como el siguiente:TensorFlow, PyTorch yJAX.

Antes de usar las TPU en GKE, te recomendamos que completes la siguiente ruta de aprendizaje:

  1. Obtén información sobre la disponibilidad actual de la versión de TPU con la arquitectura del sistema de Cloud TPU.
  2. Obtén información sobre las TPU en GKE.

En este instructivo, se aborda la entrega de varios modelos de LLM. GKE implementa el modelo en los nodos TPUv5e de host único con topologías de TPU configuradas según los requisitos del modelo para entregar mensajes con baja latencia.

Acerca de JetStream

JetStream es un framework de entrega de inferencia de código abierto que desarrolla Google. JetStream permite la inferencia de alto rendimiento, alta capacidad de procesamiento y con optimización de memoria en TPU y GPU. JetStream proporciona optimizaciones de rendimiento avanzadas, incluidas técnicas de procesamiento por lotes, optimizaciones de la caché de KV y de cuantización continuas, para facilitar la implementación de LLM. JetStream permite que PyTorch/XLA y JAX TPU entreguen un rendimiento óptimo.

Agrupación en lotes continua

El procesamiento por lotes continuo es una técnica que agrupa de forma dinámica las solicitudes de inferencia entrantes en lotes, lo que reduce la latencia y aumenta la capacidad de procesamiento.

Cuantización de la caché de KV

La cuantización de la caché de par clave-valor implica comprimir la caché de par clave-valor que se usa en los mecanismos de atención, lo que reduce los requisitos de memoria.

Cuantización del peso en Int8

La cuantización del peso de Int8 reduce la precisión de los pesos del modelo de punto flotante de 32 bits a números enteros de 8 bits, lo que permite un procesamiento más rápido y un uso de memoria reducido.

Para obtener más información sobre estas optimizaciones, consulta los repositorios de proyectos de JetStream PyTorch y JetStream MaxText.

Acerca de PyTorch

PyTorch es un framework de aprendizaje automático de código abierto desarrollado por Meta y ahora parte del paraguas de la Linux Foundation. PyTorch proporciona funciones de alto nivel, como el procesamiento de tensores y las redes neuronales profundas.