Implantar e inferir o Gemma usando endpoints com TPU do Model Garden e da Vertex AI

Neste tutorial, você usa o Model Garden para implantar o modelo aberto Gemma 2B em um endpoint da Vertex AI acelerado por TPU. Implante um modelo em um endpoint antes de ele ser usado para exibir previsões on-line. A implantação de um modelo associa recursos físicos ao modelo para que ele possa exibir previsões on-line com baixa latência.

Depois de implantar o modelo Gemma 2B, faça a inferência do modelo treinado usando o PredictionServiceClient para receber previsões on-line. As previsões on-line são solicitações síncronas feitas em um modelo implantado em um endpoint.

Implantar o Gemma usando o Model Garden

Você implanta o modelo Gemma 2B em um tipo de máquina ct5lp-hightpu-1t do Compute Engine otimizado para treinamento de pequena a média escala. Essa máquina tem um acelerador TPU v5e. Para mais informações sobre como treinar modelos usando TPUs, consulte Treinamento do Cloud TPU v5e.

Neste tutorial, você implanta o modelo aberto Gemma 2B ajustado para instruções usando o card de modelo no Model Garden. A versão específica do modelo é gemma2-2b-it. -it significa ajustado para seguir instruções.

O modelo Gemma 2B tem um tamanho de parâmetro menor, o que significa requisitos de recursos menores e mais flexibilidade de implantação.

  1. No console Google Cloud , acesse a página Model Garden.

    Acessar o Model Garden

  2. Clique no card de modelo Gemma 2.

    Acessar o Gemma 2

  3. Clique em Implantar para abrir o painel Implantar modelo.

  4. No painel Implantar modelo, especifique estes detalhes.

    1. Em Ambiente de implantação, clique em Vertex AI.

    2. Na seção Implantar modelo:

      1. Em ID do recurso, escolha gemma-2b-it.

      2. Em Nome do modelo e Nome do endpoint, aceite os valores padrão. Exemplo:

        • Nome do modelo: gemma2-2b-it-1234567891234
        • Nome do endpoint: gemma2-2b-it-mg-one-click-deploy

        Anote o nome do endpoint. Você vai precisar dele para encontrar o ID do endpoint usado nos exemplos de código.

    3. Na seção Configurações de implantação:

      1. Aceite a opção padrão para as configurações Básicas.

      2. Em Região, aceite o valor padrão ou escolha uma região na lista. Anote a região. Você vai precisar dele para os exemplos de código.

      3. Em Especificação da máquina, escolha a instância com suporte de TPU: ct5lp-hightpu-1t (1 TPU_V5_LITEPOD; ct5lp-hightpu-1t).

  5. Clique em Implantar. Quando a implantação for concluída, você vai receber um e-mail com detalhes sobre o novo endpoint. Para conferir os detalhes do endpoint, clique em Previsão on-line > Endpoints e selecione sua região.

    Acessar o Endpoints

Fazer inferência do Gemma 2B com o PredictionServiceClient

Depois de implantar o Gemma 2B, use o PredictionServiceClient para receber previsões on-line para o comando: "Por que o céu é azul?"

Parâmetros de código

As amostras de código PredictionServiceClient exigem que você atualize o seguinte.

  • PROJECT_ID: para encontrar o ID do projeto, siga estas etapas.

    1. Acesse a página Boas-vindas no console do Google Cloud .

      Acessar "Boas-vindas"

    2. No seletor de projetos na parte de cima da página, selecione seu projeto.

      O nome, o número e o ID do projeto aparecem depois do título Bem-vindo.

  • ENDPOINT_REGION: a região em que você implantou o endpoint.

  • ENDPOINT_ID: para encontrar o ID do endpoint, consulte-o no console ou execute o comando gcloud ai endpoints list. Você vai precisar do nome e da região do endpoint no painel Implantar modelo.

    Console

    Para conferir os detalhes do endpoint, clique em Previsão on-line > Endpoints e selecione sua região. Anote o número que aparece na coluna ID.

    Acessar o Endpoints

    gcloud

    Para conferir os detalhes do endpoint, execute o comando gcloud ai endpoints list.

    gcloud ai endpoints list \
      --region=ENDPOINT_REGION \
      --filter=display_name=ENDPOINT_NAME
    

    A saída é semelhante a esta:

    Using endpoint [https://us-central1-aiplatform.googleapis.com/]
    ENDPOINT_ID: 1234567891234567891
    DISPLAY_NAME: gemma2-2b-it-mg-one-click-deploy
    

Código de amostra

No exemplo de código da sua linguagem, atualize PROJECT_ID, ENDPOINT_REGION e ENDPOINT_ID. Em seguida, execute o código.

Python

Para saber como instalar o SDK da Vertex AI para Python, consulte Instalar o SDK da Vertex AI para Python. Para mais informações, consulte a documentação de referência da API Python.

"""
Sample to run inference on a Gemma2 model deployed to a Vertex AI endpoint with TPU accellerators.
"""

from google.cloud import aiplatform
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value

# TODO(developer): Update & uncomment lines below
# PROJECT_ID = "your-project-id"
# ENDPOINT_REGION = "your-vertex-endpoint-region"
# ENDPOINT_ID = "your-vertex-endpoint-id"

# Default configuration
config = {"max_tokens": 1024, "temperature": 0.9, "top_p": 1.0, "top_k": 1}

# Prompt used in the prediction
prompt = "Why is the sky blue?"

# Encapsulate the prompt in a correct format for TPUs
# Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
input = {"prompt": prompt}
input.update(config)

# Convert input message to a list of GAPIC instances for model input
instances = [json_format.ParseDict(input, Value())]

# Create a client
api_endpoint = f"{ENDPOINT_REGION}-aiplatform.googleapis.com"
client = aiplatform.gapic.PredictionServiceClient(
    client_options={"api_endpoint": api_endpoint}
)

# Call the Gemma2 endpoint
gemma2_end_point = (
    f"projects/{PROJECT_ID}/locations/{ENDPOINT_REGION}/endpoints/{ENDPOINT_ID}"
)
response = client.predict(
    endpoint=gemma2_end_point,
    instances=instances,
)
text_responses = response.predictions
print(text_responses[0])

Node.js

Antes de testar esse exemplo, siga as instruções de configuração para Node.js no Guia de início rápido da Vertex AI sobre como usar bibliotecas de cliente. Para mais informações, consulte a documentação de referência da API Vertex AI para Node.js.

Para autenticar na Vertex AI, configure o Application Default Credentials. Para mais informações, consulte Configurar a autenticação para um ambiente de desenvolvimento local.

// Imports the Google Cloud Prediction Service Client library
const {
  // TODO(developer): Uncomment PredictionServiceClient before running the sample.
  // PredictionServiceClient,
  helpers,
} = require('@google-cloud/aiplatform');
/**
 * TODO(developer): Update these variables before running the sample.
 */
const projectId = 'your-project-id';
const endpointRegion = 'your-vertex-endpoint-region';
const endpointId = 'your-vertex-endpoint-id';

// Prompt used in the prediction
const prompt = 'Why is the sky blue?';

// Encapsulate the prompt in a correct format for TPUs
// Example format: [{prompt: 'Why is the sky blue?', temperature: 0.9}]
const input = {
  prompt,
  // Parameters for default configuration
  maxOutputTokens: 1024,
  temperature: 0.9,
  topP: 1.0,
  topK: 1,
};

// Convert input message to a list of GAPIC instances for model input
const instances = [helpers.toValue(input)];

// TODO(developer): Uncomment apiEndpoint and predictionServiceClient before running the sample.
// const apiEndpoint = `${endpointRegion}-aiplatform.googleapis.com`;

// Create a client
// predictionServiceClient = new PredictionServiceClient({apiEndpoint});

// Call the Gemma2 endpoint
const gemma2Endpoint = `projects/${projectId}/locations/${endpointRegion}/endpoints/${endpointId}`;

const [response] = await predictionServiceClient.predict({
  endpoint: gemma2Endpoint,
  instances,
});

const predictions = response.predictions;
const text = predictions[0].stringValue;

console.log('Predictions:', text);

Java

Antes de testar esse exemplo, siga as instruções de configuração para Java no Guia de início rápido da Vertex AI sobre como usar bibliotecas de cliente. Para mais informações, consulte a documentação de referência da API Vertex AI para Java.

Para autenticar na Vertex AI, configure o Application Default Credentials. Para mais informações, consulte Configurar a autenticação para um ambiente de desenvolvimento local.


import com.google.cloud.aiplatform.v1.EndpointName;
import com.google.cloud.aiplatform.v1.PredictResponse;
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
import com.google.gson.Gson;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class Gemma2PredictTpu {
  private final PredictionServiceClient predictionServiceClient;

  // Constructor to inject the PredictionServiceClient
  public Gemma2PredictTpu(PredictionServiceClient predictionServiceClient) {
    this.predictionServiceClient = predictionServiceClient;
  }

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String projectId = "YOUR_PROJECT_ID";
    String endpointRegion = "us-west1";
    String endpointId = "YOUR_ENDPOINT_ID";

    PredictionServiceSettings predictionServiceSettings =
        PredictionServiceSettings.newBuilder()
            .setEndpoint(String.format("%s-aiplatform.googleapis.com:443", endpointRegion))
            .build();
    PredictionServiceClient predictionServiceClient =
        PredictionServiceClient.create(predictionServiceSettings);
    Gemma2PredictTpu creator = new Gemma2PredictTpu(predictionServiceClient);

    creator.gemma2PredictTpu(projectId, endpointRegion, endpointId);
  }

  // Demonstrates how to run inference on a Gemma2 model
  // deployed to a Vertex AI endpoint with TPU accelerators.
  public String gemma2PredictTpu(String projectId, String region,
           String endpointId) throws IOException {
    Map<String, Object> paramsMap = new HashMap<>();
    paramsMap.put("temperature", 0.9);
    paramsMap.put("maxOutputTokens", 1024);
    paramsMap.put("topP", 1.0);
    paramsMap.put("topK", 1);
    Value parameters = mapToValue(paramsMap);
    // Prompt used in the prediction
    String instance = "{ \"prompt\": \"Why is the sky blue?\"}";
    Value.Builder instanceValue = Value.newBuilder();
    JsonFormat.parser().merge(instance, instanceValue);
    // Encapsulate the prompt in a correct format for TPUs
    // Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
    List<Value> instances = new ArrayList<>();
    instances.add(instanceValue.build());

    EndpointName endpointName = EndpointName.of(projectId, region, endpointId);

    PredictResponse predictResponse = this.predictionServiceClient
        .predict(endpointName, instances, parameters);
    String textResponse = predictResponse.getPredictions(0).getStringValue();
    System.out.println(textResponse);
    return textResponse;
  }

  private static Value mapToValue(Map<String, Object> map) throws InvalidProtocolBufferException {
    Gson gson = new Gson();
    String json = gson.toJson(map);
    Value.Builder builder = Value.newBuilder();
    JsonFormat.parser().merge(json, builder);
    return builder.build();
  }
}

Go

Antes de testar esse exemplo, siga as instruções de configuração para Go no Guia de início rápido da Vertex AI sobre como usar bibliotecas de cliente. Para mais informações, consulte a documentação de referência da API Vertex AI para Go.

Para autenticar na Vertex AI, configure o Application Default Credentials. Para mais informações, consulte Configurar a autenticação para um ambiente de desenvolvimento local.

import (
	"context"
	"fmt"
	"io"

	"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"

	"google.golang.org/protobuf/types/known/structpb"
)

// predictTPU demonstrates how to run interference on a Gemma2 model deployed to a Vertex AI endpoint with TPU accelerators.
func predictTPU(w io.Writer, client PredictionsClient, projectID, location, endpointID string) error {
	ctx := context.Background()

	// Note: client can be initialized in the following way:
	// apiEndpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", location)
	// client, err := aiplatform.NewPredictionClient(ctx, option.WithEndpoint(apiEndpoint))
	// if err != nil {
	// 	return fmt.Errorf("unable to create prediction client: %v", err)
	// }
	// defer client.Close()

	gemma2Endpoint := fmt.Sprintf("projects/%s/locations/%s/endpoints/%s", projectID, location, endpointID)
	prompt := "Why is the sky blue?"
	parameters := map[string]interface{}{
		"temperature":     0.9,
		"maxOutputTokens": 1024,
		"topP":            1.0,
		"topK":            1,
	}

	// Encapsulate the prompt in a correct format for TPUs.
	// Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
	promptValue, err := structpb.NewValue(map[string]interface{}{
		"prompt":     prompt,
		"parameters": parameters,
	})
	if err != nil {
		fmt.Fprintf(w, "unable to convert prompt to Value: %v", err)
		return err
	}

	req := &aiplatformpb.PredictRequest{
		Endpoint:  gemma2Endpoint,
		Instances: []*structpb.Value{promptValue},
	}

	resp, err := client.Predict(ctx, req)
	if err != nil {
		return err
	}

	prediction := resp.GetPredictions()
	value := prediction[0].GetStringValue()
	fmt.Fprintf(w, "%v", value)

	return nil
}