Sirve Gemma mediante TPUs en GKE con JetStream

En este tutorial se muestra cómo servir un modelo de lenguaje grande (LLM) Gemma mediante unidades de procesamiento de tensor (TPUs) en Google Kubernetes Engine (GKE). Despliega un contenedor prediseñado con JetStream y MaxText en GKE. También puedes configurar GKE para que cargue los pesos de Gemma 7B desde Cloud Storage en el tiempo de ejecución.

Este tutorial está dirigido a ingenieros de aprendizaje automático, administradores y operadores de plataformas, y especialistas en datos e IA que quieran usar las funciones de orquestación de contenedores de Kubernetes para ofrecer LLMs. Para obtener más información sobre los roles habituales y las tareas de ejemplo a las que hacemos referencia en el contenido, consulta Roles y tareas de usuario habituales de GKE.Google Cloud

Antes de leer esta página, asegúrese de que conoce los siguientes conceptos:

Fondo

En esta sección se describen las tecnologías clave que se usan en este tutorial.

Gemma

Gemma es un conjunto de modelos de inteligencia artificial (IA) generativa ligeros y disponibles públicamente que se han lanzado con una licencia abierta. Estos modelos de IA se pueden ejecutar en tus aplicaciones, hardware, dispositivos móviles o servicios alojados. Puedes usar los modelos Gemma para generar texto, pero también puedes ajustarlos para tareas especializadas.

Para obtener más información, consulta la documentación de Gemma.

TPUs

Las TPUs son circuitos integrados para aplicaciones específicas (ASIC) desarrollados a medida por Google que se utilizan para acelerar los modelos de aprendizaje automático y de IA creados con frameworks como TensorFlow, PyTorch y JAX.

En este tutorial se explica cómo servir el modelo Gemma 7B. GKE implementa el modelo en nodos TPU v5e de un solo host con topologías de TPU configuradas en función de los requisitos del modelo para servir peticiones con baja latencia.

JetStream

JetStream es un framework de servicio de inferencia de código abierto desarrollado por Google. JetStream permite realizar inferencias de alto rendimiento, alto volumen de procesamiento y memoria optimizada en TPUs y GPUs. Ofrece optimizaciones de rendimiento avanzadas, como técnicas de cuantización y procesamiento por lotes continuo, para facilitar la implementación de LLMs. JetStream permite que el servicio de PyTorch/XLA y JAX TPU consiga un rendimiento óptimo.

Para obtener más información sobre estas optimizaciones, consulta los repositorios de proyectos JetStream PyTorch y JetStream MaxText.

MaxText

MaxText es una implementación de LLM de JAX de alto rendimiento, escalable y adaptable, creada a partir de bibliotecas de JAX de código abierto, como Flax, Orbax y Optax. La implementación de LLM solo con decodificador de MaxText está escrita en Python. Aprovecha el compilador XLA para conseguir un alto rendimiento sin necesidad de crear kernels personalizados.

Para obtener más información sobre los modelos y tamaños de parámetros más recientes que admite MaxText, consulta el repositorio del proyecto MaxText.