Diffuser Gemma à l'aide de TPU sur GKE avec JetStream

Ce tutoriel explique comment diffuser un grand modèle de langage (LLM) Gemma à l'aide des TPU (Tensor Processing Units) sur Google Kubernetes Engine (GKE). Vous déployez un conteneur prédéfini avec JetStream et MaxText sur GKE. Vous configurez également GKE pour qu'il charge les pondérations Gemma 7B depuis Cloud Storage au moment de l'exécution.

Ce tutoriel est destiné aux ingénieurs en machine learning (ML), aux administrateurs et opérateurs de plate-forme, ainsi qu'aux spécialistes des données et de l'IA qui souhaitent utiliser les fonctionnalités d'orchestration de conteneurs Kubernetes pour diffuser des LLM. Pour en savoir plus sur les rôles courants et les exemples de tâches que nous citons dans le contenuGoogle Cloud , consultez Rôles utilisateur et tâches courantes de GKE.

Avant de lire cette page, assurez-vous de connaître les éléments suivants :

Arrière-plan

Cette section décrit les principales technologies utilisées dans ce tutoriel.

Gemma

Gemma est un ensemble de modèles d'intelligence artificielle (IA) générative, légers et disponibles publiquement, publiés sous licence ouverte. Ces modèles d'IA sont disponibles pour s'exécuter dans vos applications, votre matériel, vos appareils mobiles ou vos services hébergés. Vous pouvez utiliser les modèles Gemma pour la génération de texte, mais vous pouvez également les ajuster pour des tâches spécialisées.

Pour en savoir plus, consultez la documentation Gemma.

TPU

Les TPU sont des circuits intégrés propres aux applications (Application-Specific Integrated Circuit ou ASIC), développés spécifiquement par Google et permettant d'accélérer le machine learning et les modèles d'IA créés à l'aide de frameworks tels que TensorFlow, PyTorch et JAX.

Ce tutoriel explique comment diffuser le modèle Gemma 7B. GKE déploie le modèle sur des nœuds TPUv5e à hôte unique avec des topologies TPU configurées en fonction des exigences du modèle pour diffuser des requêtes avec une faible latence.

JetStream

JetStream est un framework de diffusion d'inférences Open Source développé par Google. JetStream permet des inférences hautes performances, à haut débit et à mémoire optimisée sur les TPU et les GPU. Il fournit des optimisations de performances avancées, y compris des techniques de traitement par lot et de quantification continues, pour faciliter le déploiement de LLM. JetStream permet le serving de TPU PyTorch/XLA et JAX pour obtenir des performances optimales.

Pour en savoir plus sur ces optimisations, consultez les dépôts de projets JetStream PyTorch et JetStream MaxText.

MaxText

MaxText est une implémentation LLM JAX performante, évolutive et adaptable, basée sur des bibliothèques JAX Open Source telles que Flax, Orbax et Optax. L'implémentation LLM uniquement décodeur de MaxText est écrite en Python. Elle exploite fortement le compilateur XLA pour atteindre de hautes performances sans avoir à créer de noyau personnalisé.

Pour en savoir plus sur les derniers modèles et tailles de paramètres compatibles avec MaxText, consultez le dépôt du projet MaxText.