Disponibilizar um LLM usando TPUs no GKE com o JetStream e o PyTorch

Este guia mostra como exibir um modelo de linguagem grande (LLM) usando Unidades de Processamento de Tensor (TPUs) no Google Kubernetes Engine (GKE) com JetStream através do PyTorch. Neste guia, você baixa pesos de modelo para o Cloud Storage e os implanta em um cluster do GKE Autopilot ou Standard usando um contêiner que executa o JetStream.

Este guia é um bom ponto de partida se você precisar da escalonabilidade, resiliência e economia oferecidos pelos recursos do Kubernetes ao implantar o modelo no JetStream.

Este guia é destinado a clientes de IA generativa que usam o PyTorch, usuários novos ou atuais do GKE, engenheiros de ML, engenheiros de MLOps (DevOps) ou administradores de plataformas interessados em usar os recursos de orquestração de contêineres do Kubernetes para veiculação de LLMs.

Contexto

Ao disponibilizar o LLM usando TPUs no GKE com o JetStream, é possível criar uma solução de exibição robusta e pronta para produção com todos os benefícios do Kubernetes gerenciado, incluindo economia, escalonabilidade e maior disponibilidade. Esta seção descreve as principais tecnologias usadas neste tutorial.

Sobre TPUs

TPUs são circuitos integrados de aplicação específica (ASICs, na sigla em inglês) desenvolvidos especialmente pelo Google. Eles são usados para acelerar modelos de machine learning e de IA criados com o uso de frameworks comoTensorFlow , PyTorch eJAX.

Antes de usar TPUs no GKE, recomendamos que você conclua o seguinte programa de aprendizado:

  1. Saiba mais sobre a disponibilidade atual da versão da TPU com a arquitetura do sistema do Cloud TPU.
  2. Saiba mais sobre TPUs no GKE.

Este tutorial aborda a disponibilização de vários modelos de LLM. O GKE implanta o modelo em nós TPUv5e de host único com topologias de TPU configuradas com base nos requisitos do modelo para exibir prompts com baixa latência.

Sobre o JetStream

O JetStream é um framework de veiculação de inferência de código aberto desenvolvido pelo Google. O JetStream permite a inferência de alto desempenho, alta capacidade e otimização de memória em TPUs e GPUs. O JetStream oferece otimizações avançadas de desempenho, incluindo lotes contínuos, otimizações de cache KV e técnicas de quantização, para facilitar a implantação de LLMs. O JetStream permite a veiculação de TPU do PyTorch/XLA e do JAX para alcançar o desempenho ideal.

Lotes contínuos

O agrupamento contínuo é uma técnica que agrupa dinamicamente as solicitações de inferência recebidas em lotes, reduzindo a latência e aumentando a capacidade de processamento.

Quantização de cache KV

A quantização do cache KV envolve a compactação do cache de valor-chave usado em mecanismos de atenção, reduzindo os requisitos de memória.

Quantização de peso Int8

A quantização de peso Int8 reduz a precisão dos pesos do modelo de ponto flutuante de 32 bits para inteiros de 8 bits, resultando em computação mais rápida e uso reduzido da memória.

Para saber mais sobre essas otimizações, consulte os repositórios de projetos JetStream PyTorch e JetStream MaxText (links em inglês).

Sobre o PyTorch

O PyTorch é um framework de aprendizado de máquina de código aberto desenvolvido pela Meta e agora faz parte da Linux Foundation. O PyTorch fornece recursos de alto nível, como computação de tensor e redes neurais profundas.