Diffuser un LLM à l'aide de TPU sur GKE avec JetStream et PyTorch

Ce guide vous explique comment diffuser un grand modèle de langage (LLM) à l'aide des TPU (Tensor Processing Units) sur Google Kubernetes Engine (GKE) avec JetStream via PyTorch. Dans ce guide, vous téléchargez des pondérations de modèle dans Cloud Storage et les déployez sur un cluster GKE Autopilot ou Standard à l'aide d'un conteneur exécutant JetStream

Si vous avez besoin de l'évolutivité, de la résilience et de la rentabilité offertes par les fonctionnalités de Kubernetes lors du déploiement de votre modèle sur JetStream, ce guide est un bon point de départ.

Ce guide est destiné aux clients d'IA générative qui utilisent PyTorch, aux utilisateurs nouveaux ou existants de GKE, aux ingénieurs en ML, aux ingénieurs MLOps (DevOps) ou aux administrateurs de plate-forme qui souhaitent utiliser les fonctionnalités d'orchestration de conteneurs Kubernetes pour diffuser des LLM.

Contexte

En diffusant un LLM à l'aide de TPU sur GKE avec JetStream, vous pouvez créer une solution de diffusion robuste et prête pour la production avec tous les avantages de la plate-forme Kubernetes gérée, y compris en termes de rentabilité, évolutivité et haute disponibilité. Cette section décrit les principales technologies utilisées dans ce tutoriel.

À propos des TPU

Les TPU sont des circuits intégrés propres aux applications (Application-Specific Integrated Circuit ou ASIC), développés spécifiquement par Google et permettant d'accélérer le machine learning et les modèles d'IA créés à l'aide de frameworks tels que TensorFlow, PyTorch et JAX.

Avant d'utiliser des TPU dans GKE, nous vous recommandons de suivre le parcours de formation suivant :

  1. Découvrez la disponibilité actuelle des versions de TPU avec l'architecture système de Cloud TPU.
  2. Apprenez-en plus sur les TPU dans GKE.

Ce tutoriel explique comment diffuser différents modèles LLM. GKE déploie le modèle sur des nœuds TPUv5e à hôte unique avec des topologies TPU configurées en fonction des exigences du modèle pour diffuser des requêtes avec une faible latence.

À propos de JetStream

JetStream est un framework de diffusion d'inférences Open Source développé par Google. JetStream permet des inférences hautes performances, à haut débit et à mémoire optimisée sur les TPU et les GPU. JetStream fournit des optimisations de performances avancées, y compris des techniques de traitement par lot continu, d'optimisation du cache KV et de quantification, pour faciliter le déploiement de LLM. JetStream permet aux services TPU PyTorch/XLA et JAX d'atteindre des performances optimales.

Traitement par lots continu

Le traitement par lot continu est une technique qui regroupe dynamiquement les requêtes d'inférence entrantes en lots, ce qui réduit la latence et augmente le débit.

Quantification du cache KV

La quantification du cache KV consiste à compresser le cache clé-valeur utilisé dans les mécanismes d'attention, ce qui réduit les besoins en mémoire.

Quantification des poids Int8

La quantification des poids Int8 réduit la précision des poids du modèle de 32 bits à virgule flottante à des entiers de 8 bits, ce qui accélère le calcul et réduit l'utilisation de mémoire.

Pour en savoir plus sur ces optimisations, consultez les dépôts de projets JetStream PyTorch et JetStream MaxText.

À propos de PyTorch

PyTorch est un framework de machine learning Open Source développé par Meta et qui fait désormais partie de la Linux Foundation. PyTorch fournit des fonctionnalités de haut niveau, telles que le calcul Tensor et les réseaux de neurones profonds.