Implemente e faça inferência do Gemma através do Model Garden e dos pontos finais suportados pela TPU do Vertex AI

Neste tutorial, vai usar o Model Garden para implementar o modelo aberto Gemma 2B num ponto final do Vertex AI suportado por uma TPU. Tem de implementar um modelo num ponto final antes de poder usar esse modelo para publicar previsões online. A implementação de um modelo associa recursos físicos ao modelo para que possa publicar previsões online com baixa latência.

Depois de implementar o modelo Gemma 2B, infere o modelo preparado usando o PredictionServiceClient para obter previsões online. As previsões online são pedidos síncronos feitos a um modelo implementado num ponto final.

Implemente o Gemma através do Model Garden

Implementa o modelo Gemma 2B num tipo de máquina do Compute Engine ct5lp-hightpu-1t otimizado para o treino de pequena a média escala. Esta máquina tem um acelerador TPU v5e. Para mais informações sobre a preparação de modelos com TPUs, consulte o artigo Preparação de Cloud TPU v5e.

Neste tutorial, implementa o modelo aberto Gemma 2B com ajuste fino de instruções usando o cartão do modelo no Model Garden. A versão específica do modelo é gemma2-2b-it. -it significa ajustado para instruções.

O modelo Gemma 2B tem um tamanho de parâmetro inferior, o que significa requisitos de recursos mais baixos e maior flexibilidade de implementação.

  1. Na Google Cloud consola, aceda à página Model Garden.

    Aceda ao Model Garden

  2. Clique no cartão do modelo Gemma 2.

    Aceda ao Gemma 2

  3. Clique em Implementar para abrir o painel Implementar modelo.

  4. No painel Implementar modelo, especifique estes detalhes.

    1. Em Ambiente de implementação, clique em Vertex AI.

    2. Na secção Implementar modelo:

      1. Para ID do recurso, escolha gemma-2b-it.

      2. Para Nome do modelo e Nome do ponto final, aceite os valores predefinidos. Por exemplo:

        • Nome do modelo: gemma2-2b-it-1234567891234
        • Nome do ponto final: gemma2-2b-it-mg-one-click-deploy

        Tome nota do nome do ponto final. Precisa dele para encontrar o ID do ponto final usado nos exemplos de código.

    3. Na secção Definições de implementação:

      1. Aceite a opção predefinida para as definições Básicas.

      2. Para Região, aceite o valor predefinido ou escolha uma região na lista. Tome nota da região. Vai precisar dele para os exemplos de código.

      3. Para Especificações da máquina, escolha a instância suportada pela TPU: ct5lp-hightpu-1t (1 TPU_V5_LITEPOD; ct5lp-hightpu-1t).

  5. Clique em Implementar. Quando a implementação estiver concluída, recebe um email que contém detalhes sobre o seu novo ponto final. Também pode ver os detalhes do ponto final clicando em Previsão online > Pontos finais e selecionando a sua região.

    Aceda a Pontos finais

Faça inferência do Gemma 2B com o PredictionServiceClient

Depois de implementar o Gemma 2B, usa a PredictionServiceClient para obter previsões online para o comando: "Porque é que o céu é azul?"

Parâmetros de código

Os PredictionServiceClient exemplos de código requerem que atualize o seguinte.

  • PROJECT_ID: para encontrar o ID do projeto, siga estes passos.

    1. Aceda à página Boas-vindas na Google Cloud consola.

      Aceder a Boas-vindas

    2. No seletor de projetos na parte superior da página, selecione o seu projeto.

      O nome do projeto, o número do projeto e o ID do projeto aparecem após o cabeçalho Bem-vindo.

  • ENDPOINT_REGION: esta é a região onde implementou o ponto final.

  • ENDPOINT_ID: para encontrar o ID do ponto final, veja-o na consola ou execute o comando gcloud ai endpoints list. Precisa do nome do ponto final e da região do painel Implementar modelo.

    Consola

    Pode ver os detalhes do ponto final clicando em Previsão online > Pontos finais e selecionando a sua região. Tenha em atenção o número apresentado na coluna ID.

    Aceda a Pontos finais

    gcloud

    Pode ver os detalhes do ponto final executando o comando gcloud ai endpoints list.

    gcloud ai endpoints list \
      --region=ENDPOINT_REGION \
      --filter=display_name=ENDPOINT_NAME
    

    O resultado tem o seguinte aspeto.

    Using endpoint [https://us-central1-aiplatform.googleapis.com/]
    ENDPOINT_ID: 1234567891234567891
    DISPLAY_NAME: gemma2-2b-it-mg-one-click-deploy
    

Exemplo de código

No exemplo de código para o seu idioma, atualize os valores PROJECT_ID, ENDPOINT_REGION e ENDPOINT_ID. Em seguida, execute o código.

Python

Para saber como instalar ou atualizar o SDK Vertex AI para Python, consulte o artigo Instale o SDK Vertex AI para Python. Para mais informações, consulte a Python documentação de referência da API.

"""
Sample to run inference on a Gemma2 model deployed to a Vertex AI endpoint with TPU accellerators.
"""

from google.cloud import aiplatform
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value

# TODO(developer): Update & uncomment lines below
# PROJECT_ID = "your-project-id"
# ENDPOINT_REGION = "your-vertex-endpoint-region"
# ENDPOINT_ID = "your-vertex-endpoint-id"

# Default configuration
config = {"max_tokens": 1024, "temperature": 0.9, "top_p": 1.0, "top_k": 1}

# Prompt used in the prediction
prompt = "Why is the sky blue?"

# Encapsulate the prompt in a correct format for TPUs
# Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
input = {"prompt": prompt}
input.update(config)

# Convert input message to a list of GAPIC instances for model input
instances = [json_format.ParseDict(input, Value())]

# Create a client
api_endpoint = f"{ENDPOINT_REGION}-aiplatform.googleapis.com"
client = aiplatform.gapic.PredictionServiceClient(
    client_options={"api_endpoint": api_endpoint}
)

# Call the Gemma2 endpoint
gemma2_end_point = (
    f"projects/{PROJECT_ID}/locations/{ENDPOINT_REGION}/endpoints/{ENDPOINT_ID}"
)
response = client.predict(
    endpoint=gemma2_end_point,
    instances=instances,
)
text_responses = response.predictions
print(text_responses[0])

Node.js

Antes de experimentar este exemplo, siga as Node.jsinstruções de configuração no início rápido do Vertex AI com bibliotecas de cliente. Para mais informações, consulte a documentação de referência da API Node.js Vertex AI.

Para se autenticar no Vertex AI, configure as Credenciais padrão da aplicação. Para mais informações, consulte o artigo Configure a autenticação para um ambiente de desenvolvimento local.

// Imports the Google Cloud Prediction Service Client library
const {
  // TODO(developer): Uncomment PredictionServiceClient before running the sample.
  // PredictionServiceClient,
  helpers,
} = require('@google-cloud/aiplatform');
/**
 * TODO(developer): Update these variables before running the sample.
 */
const projectId = 'your-project-id';
const endpointRegion = 'your-vertex-endpoint-region';
const endpointId = 'your-vertex-endpoint-id';

// Prompt used in the prediction
const prompt = 'Why is the sky blue?';

// Encapsulate the prompt in a correct format for TPUs
// Example format: [{prompt: 'Why is the sky blue?', temperature: 0.9}]
const input = {
  prompt,
  // Parameters for default configuration
  maxOutputTokens: 1024,
  temperature: 0.9,
  topP: 1.0,
  topK: 1,
};

// Convert input message to a list of GAPIC instances for model input
const instances = [helpers.toValue(input)];

// TODO(developer): Uncomment apiEndpoint and predictionServiceClient before running the sample.
// const apiEndpoint = `${endpointRegion}-aiplatform.googleapis.com`;

// Create a client
// predictionServiceClient = new PredictionServiceClient({apiEndpoint});

// Call the Gemma2 endpoint
const gemma2Endpoint = `projects/${projectId}/locations/${endpointRegion}/endpoints/${endpointId}`;

const [response] = await predictionServiceClient.predict({
  endpoint: gemma2Endpoint,
  instances,
});

const predictions = response.predictions;
const text = predictions[0].stringValue;

console.log('Predictions:', text);

Java

Antes de experimentar este exemplo, siga as Javainstruções de configuração no início rápido do Vertex AI com bibliotecas de cliente. Para mais informações, consulte a documentação de referência da API Java Vertex AI.

Para se autenticar no Vertex AI, configure as Credenciais padrão da aplicação. Para mais informações, consulte o artigo Configure a autenticação para um ambiente de desenvolvimento local.


import com.google.cloud.aiplatform.v1.EndpointName;
import com.google.cloud.aiplatform.v1.PredictResponse;
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
import com.google.gson.Gson;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class Gemma2PredictTpu {
  private final PredictionServiceClient predictionServiceClient;

  // Constructor to inject the PredictionServiceClient
  public Gemma2PredictTpu(PredictionServiceClient predictionServiceClient) {
    this.predictionServiceClient = predictionServiceClient;
  }

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String projectId = "YOUR_PROJECT_ID";
    String endpointRegion = "us-west1";
    String endpointId = "YOUR_ENDPOINT_ID";

    PredictionServiceSettings predictionServiceSettings =
        PredictionServiceSettings.newBuilder()
            .setEndpoint(String.format("%s-aiplatform.googleapis.com:443", endpointRegion))
            .build();
    PredictionServiceClient predictionServiceClient =
        PredictionServiceClient.create(predictionServiceSettings);
    Gemma2PredictTpu creator = new Gemma2PredictTpu(predictionServiceClient);

    creator.gemma2PredictTpu(projectId, endpointRegion, endpointId);
  }

  // Demonstrates how to run inference on a Gemma2 model
  // deployed to a Vertex AI endpoint with TPU accelerators.
  public String gemma2PredictTpu(String projectId, String region,
           String endpointId) throws IOException {
    Map<String, Object> paramsMap = new HashMap<>();
    paramsMap.put("temperature", 0.9);
    paramsMap.put("maxOutputTokens", 1024);
    paramsMap.put("topP", 1.0);
    paramsMap.put("topK", 1);
    Value parameters = mapToValue(paramsMap);
    // Prompt used in the prediction
    String instance = "{ \"prompt\": \"Why is the sky blue?\"}";
    Value.Builder instanceValue = Value.newBuilder();
    JsonFormat.parser().merge(instance, instanceValue);
    // Encapsulate the prompt in a correct format for TPUs
    // Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
    List<Value> instances = new ArrayList<>();
    instances.add(instanceValue.build());

    EndpointName endpointName = EndpointName.of(projectId, region, endpointId);

    PredictResponse predictResponse = this.predictionServiceClient
        .predict(endpointName, instances, parameters);
    String textResponse = predictResponse.getPredictions(0).getStringValue();
    System.out.println(textResponse);
    return textResponse;
  }

  private static Value mapToValue(Map<String, Object> map) throws InvalidProtocolBufferException {
    Gson gson = new Gson();
    String json = gson.toJson(map);
    Value.Builder builder = Value.newBuilder();
    JsonFormat.parser().merge(json, builder);
    return builder.build();
  }
}

Go

Antes de experimentar este exemplo, siga as Goinstruções de configuração no início rápido do Vertex AI com bibliotecas de cliente. Para mais informações, consulte a documentação de referência da API Go Vertex AI.

Para se autenticar no Vertex AI, configure as Credenciais padrão da aplicação. Para mais informações, consulte o artigo Configure a autenticação para um ambiente de desenvolvimento local.

import (
	"context"
	"fmt"
	"io"

	"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"

	"google.golang.org/protobuf/types/known/structpb"
)

// predictTPU demonstrates how to run interference on a Gemma2 model deployed to a Vertex AI endpoint with TPU accelerators.
func predictTPU(w io.Writer, client PredictionsClient, projectID, location, endpointID string) error {
	ctx := context.Background()

	// Note: client can be initialized in the following way:
	// apiEndpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", location)
	// client, err := aiplatform.NewPredictionClient(ctx, option.WithEndpoint(apiEndpoint))
	// if err != nil {
	// 	return fmt.Errorf("unable to create prediction client: %v", err)
	// }
	// defer client.Close()

	gemma2Endpoint := fmt.Sprintf("projects/%s/locations/%s/endpoints/%s", projectID, location, endpointID)
	prompt := "Why is the sky blue?"
	parameters := map[string]interface{}{
		"temperature":     0.9,
		"maxOutputTokens": 1024,
		"topP":            1.0,
		"topK":            1,
	}

	// Encapsulate the prompt in a correct format for TPUs.
	// Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
	promptValue, err := structpb.NewValue(map[string]interface{}{
		"prompt":     prompt,
		"parameters": parameters,
	})
	if err != nil {
		fmt.Fprintf(w, "unable to convert prompt to Value: %v", err)
		return err
	}

	req := &aiplatformpb.PredictRequest{
		Endpoint:  gemma2Endpoint,
		Instances: []*structpb.Value{promptValue},
	}

	resp, err := client.Predict(ctx, req)
	if err != nil {
		return err
	}

	prediction := resp.GetPredictions()
	value := prediction[0].GetStringValue()
	fmt.Fprintf(w, "%v", value)

	return nil
}