Desplegar y usar Gemma para inferencias con Model Garden y endpoints respaldados por TPU de Vertex AI

En este tutorial, usarás Model Garden para desplegar el modelo abierto Gemma 2B en un endpoint de Vertex AI con TPU. Debes desplegar un modelo en un endpoint para poder usarlo y ofrecer predicciones online. Al desplegar un modelo, se asocian recursos físicos a este para que pueda ofrecer predicciones online con baja latencia.

Una vez que hayas desplegado el modelo Gemma 2B, podrás inferir el modelo entrenado mediante PredictionServiceClient para obtener predicciones online. Las predicciones online son solicitudes síncronas que se envían a un modelo desplegado en un endpoint.

Desplegar Gemma con Model Garden

Implementa el modelo Gemma 2B en un ct5lp-hightpu-1t tipo de máquina de Compute Engine optimizado para el entrenamiento a pequeña y mediana escala. Esta máquina tiene un acelerador TPU v5e. Para obtener más información sobre cómo entrenar modelos con TPUs, consulta el artículo sobre el entrenamiento con la versión 5e de TPU de Cloud.

En este tutorial, desplegarás el modelo abierto Gemma 2B ajustado para seguir instrucciones mediante la tarjeta de modelo de Model Garden. La versión específica del modelo es gemma2-2b-it. -it significa ajustado para instrucciones.

El modelo Gemma 2B tiene un tamaño de parámetro más pequeño, lo que significa que requiere menos recursos y ofrece más flexibilidad de implementación.

  1. En la Google Cloud consola, ve a la página Model Garden.

    Ir a Model Garden

  2. Haz clic en la tarjeta del modelo Gemma 2.

    Ir a Gemma 2

  3. Haga clic en Implementar para abrir el panel Implementar modelo.

  4. En el panel Implementar modelo, especifica estos detalles.

    1. En Entorno de implementación, haz clic en Vertex AI.

    2. En la sección Desplegar modelo:

      1. En ID de recurso, elige gemma-2b-it.

      2. En Nombre del modelo y Nombre del endpoint, acepta los valores predeterminados. Por ejemplo:

        • Nombre del modelo: gemma2-2b-it-1234567891234
        • Nombre del endpoint: gemma2-2b-it-mg-one-click-deploy

        Anota el nombre del endpoint. Lo necesitará para encontrar el ID de endpoint que se usa en los ejemplos de código.

    3. En la sección Ajustes del despliegue:

      1. Acepta la opción predeterminada de Básico.

      2. En Región, acepta el valor predeterminado o elige una región de la lista. Anota la región. Lo necesitarás para los ejemplos de código.

      3. En Especificación de la máquina, elige la instancia con TPU: ct5lp-hightpu-1t (1 TPU_V5_LITEPOD; ct5lp-hightpu-1t).

  5. Haz clic en Desplegar. Cuando finalice la implementación, recibirás un correo con los detalles de tu nuevo endpoint. También puedes ver los detalles del endpoint haciendo clic en Predicción online > Endpoints y seleccionando tu región.

    Ir a Endpoints

Inferencia de Gemma 2B con PredictionServiceClient

Después de desplegar Gemma 2B, usas la PredictionServiceClient para obtener predicciones online de la petición "¿Por qué el cielo es azul?".

Parámetros de código

En los ejemplos de código de PredictionServiceClient, debes actualizar lo siguiente.

  • PROJECT_ID: Para encontrar el ID de tu proyecto, sigue estos pasos.

    1. Ve a la página Bienvenida de la Google Cloud consola.

      Ir a Bienvenida

    2. En el selector de proyectos de la parte superior de la página, selecciona tu proyecto.

      El nombre, el número y el ID del proyecto aparecen después del encabezado Bienvenido.

  • ENDPOINT_REGION: es la región en la que has implementado el endpoint.

  • ENDPOINT_ID: Para encontrar tu ID de endpoint, consúltalo en la consola o ejecuta el comando gcloud ai endpoints list. Necesitarás el nombre y la región del endpoint del panel Implementar modelo.

    Consola

    Para ver los detalles del endpoint, haga clic en Predicción online > Endpoints y seleccione su región. Fíjate en el número que aparece en la columna ID.

    Ir a Endpoints

    gcloud

    Para ver los detalles del endpoint, ejecuta el comando gcloud ai endpoints list.

    gcloud ai endpoints list \
      --region=ENDPOINT_REGION \
      --filter=display_name=ENDPOINT_NAME
    

    La salida tiene este aspecto.

    Using endpoint [https://us-central1-aiplatform.googleapis.com/]
    ENDPOINT_ID: 1234567891234567891
    DISPLAY_NAME: gemma2-2b-it-mg-one-click-deploy
    

Código de muestra

En el código de ejemplo de tu idioma, actualiza PROJECT_ID, ENDPOINT_REGION y ENDPOINT_ID. A continuación, ejecuta el código.

Python

Para saber cómo instalar o actualizar el SDK de Vertex AI para Python, consulta Instalar el SDK de Vertex AI para Python. Para obtener más información, consulta la documentación de referencia de la API Python.

"""
Sample to run inference on a Gemma2 model deployed to a Vertex AI endpoint with TPU accellerators.
"""

from google.cloud import aiplatform
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value

# TODO(developer): Update & uncomment lines below
# PROJECT_ID = "your-project-id"
# ENDPOINT_REGION = "your-vertex-endpoint-region"
# ENDPOINT_ID = "your-vertex-endpoint-id"

# Default configuration
config = {"max_tokens": 1024, "temperature": 0.9, "top_p": 1.0, "top_k": 1}

# Prompt used in the prediction
prompt = "Why is the sky blue?"

# Encapsulate the prompt in a correct format for TPUs
# Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
input = {"prompt": prompt}
input.update(config)

# Convert input message to a list of GAPIC instances for model input
instances = [json_format.ParseDict(input, Value())]

# Create a client
api_endpoint = f"{ENDPOINT_REGION}-aiplatform.googleapis.com"
client = aiplatform.gapic.PredictionServiceClient(
    client_options={"api_endpoint": api_endpoint}
)

# Call the Gemma2 endpoint
gemma2_end_point = (
    f"projects/{PROJECT_ID}/locations/{ENDPOINT_REGION}/endpoints/{ENDPOINT_ID}"
)
response = client.predict(
    endpoint=gemma2_end_point,
    instances=instances,
)
text_responses = response.predictions
print(text_responses[0])

Node.js

Antes de probar este ejemplo, sigue las Node.js instrucciones de configuración de la guía de inicio rápido de Vertex AI con bibliotecas de cliente. Para obtener más información, consulta la documentación de referencia de la API Node.js de Vertex AI.

Para autenticarte en Vertex AI, configura las credenciales predeterminadas de la aplicación. Para obtener más información, consulta el artículo Configurar la autenticación en un entorno de desarrollo local.

// Imports the Google Cloud Prediction Service Client library
const {
  // TODO(developer): Uncomment PredictionServiceClient before running the sample.
  // PredictionServiceClient,
  helpers,
} = require('@google-cloud/aiplatform');
/**
 * TODO(developer): Update these variables before running the sample.
 */
const projectId = 'your-project-id';
const endpointRegion = 'your-vertex-endpoint-region';
const endpointId = 'your-vertex-endpoint-id';

// Prompt used in the prediction
const prompt = 'Why is the sky blue?';

// Encapsulate the prompt in a correct format for TPUs
// Example format: [{prompt: 'Why is the sky blue?', temperature: 0.9}]
const input = {
  prompt,
  // Parameters for default configuration
  maxOutputTokens: 1024,
  temperature: 0.9,
  topP: 1.0,
  topK: 1,
};

// Convert input message to a list of GAPIC instances for model input
const instances = [helpers.toValue(input)];

// TODO(developer): Uncomment apiEndpoint and predictionServiceClient before running the sample.
// const apiEndpoint = `${endpointRegion}-aiplatform.googleapis.com`;

// Create a client
// predictionServiceClient = new PredictionServiceClient({apiEndpoint});

// Call the Gemma2 endpoint
const gemma2Endpoint = `projects/${projectId}/locations/${endpointRegion}/endpoints/${endpointId}`;

const [response] = await predictionServiceClient.predict({
  endpoint: gemma2Endpoint,
  instances,
});

const predictions = response.predictions;
const text = predictions[0].stringValue;

console.log('Predictions:', text);

Java

Antes de probar este ejemplo, sigue las Java instrucciones de configuración de la guía de inicio rápido de Vertex AI con bibliotecas de cliente. Para obtener más información, consulta la documentación de referencia de la API Java de Vertex AI.

Para autenticarte en Vertex AI, configura las credenciales predeterminadas de la aplicación. Para obtener más información, consulta el artículo Configurar la autenticación en un entorno de desarrollo local.


import com.google.cloud.aiplatform.v1.EndpointName;
import com.google.cloud.aiplatform.v1.PredictResponse;
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
import com.google.gson.Gson;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class Gemma2PredictTpu {
  private final PredictionServiceClient predictionServiceClient;

  // Constructor to inject the PredictionServiceClient
  public Gemma2PredictTpu(PredictionServiceClient predictionServiceClient) {
    this.predictionServiceClient = predictionServiceClient;
  }

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String projectId = "YOUR_PROJECT_ID";
    String endpointRegion = "us-west1";
    String endpointId = "YOUR_ENDPOINT_ID";

    PredictionServiceSettings predictionServiceSettings =
        PredictionServiceSettings.newBuilder()
            .setEndpoint(String.format("%s-aiplatform.googleapis.com:443", endpointRegion))
            .build();
    PredictionServiceClient predictionServiceClient =
        PredictionServiceClient.create(predictionServiceSettings);
    Gemma2PredictTpu creator = new Gemma2PredictTpu(predictionServiceClient);

    creator.gemma2PredictTpu(projectId, endpointRegion, endpointId);
  }

  // Demonstrates how to run inference on a Gemma2 model
  // deployed to a Vertex AI endpoint with TPU accelerators.
  public String gemma2PredictTpu(String projectId, String region,
           String endpointId) throws IOException {
    Map<String, Object> paramsMap = new HashMap<>();
    paramsMap.put("temperature", 0.9);
    paramsMap.put("maxOutputTokens", 1024);
    paramsMap.put("topP", 1.0);
    paramsMap.put("topK", 1);
    Value parameters = mapToValue(paramsMap);
    // Prompt used in the prediction
    String instance = "{ \"prompt\": \"Why is the sky blue?\"}";
    Value.Builder instanceValue = Value.newBuilder();
    JsonFormat.parser().merge(instance, instanceValue);
    // Encapsulate the prompt in a correct format for TPUs
    // Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
    List<Value> instances = new ArrayList<>();
    instances.add(instanceValue.build());

    EndpointName endpointName = EndpointName.of(projectId, region, endpointId);

    PredictResponse predictResponse = this.predictionServiceClient
        .predict(endpointName, instances, parameters);
    String textResponse = predictResponse.getPredictions(0).getStringValue();
    System.out.println(textResponse);
    return textResponse;
  }

  private static Value mapToValue(Map<String, Object> map) throws InvalidProtocolBufferException {
    Gson gson = new Gson();
    String json = gson.toJson(map);
    Value.Builder builder = Value.newBuilder();
    JsonFormat.parser().merge(json, builder);
    return builder.build();
  }
}

Go

Antes de probar este ejemplo, sigue las Go instrucciones de configuración de la guía de inicio rápido de Vertex AI con bibliotecas de cliente. Para obtener más información, consulta la documentación de referencia de la API Go de Vertex AI.

Para autenticarte en Vertex AI, configura las credenciales predeterminadas de la aplicación. Para obtener más información, consulta el artículo Configurar la autenticación en un entorno de desarrollo local.

import (
	"context"
	"fmt"
	"io"

	"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"

	"google.golang.org/protobuf/types/known/structpb"
)

// predictTPU demonstrates how to run interference on a Gemma2 model deployed to a Vertex AI endpoint with TPU accelerators.
func predictTPU(w io.Writer, client PredictionsClient, projectID, location, endpointID string) error {
	ctx := context.Background()

	// Note: client can be initialized in the following way:
	// apiEndpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", location)
	// client, err := aiplatform.NewPredictionClient(ctx, option.WithEndpoint(apiEndpoint))
	// if err != nil {
	// 	return fmt.Errorf("unable to create prediction client: %v", err)
	// }
	// defer client.Close()

	gemma2Endpoint := fmt.Sprintf("projects/%s/locations/%s/endpoints/%s", projectID, location, endpointID)
	prompt := "Why is the sky blue?"
	parameters := map[string]interface{}{
		"temperature":     0.9,
		"maxOutputTokens": 1024,
		"topP":            1.0,
		"topK":            1,
	}

	// Encapsulate the prompt in a correct format for TPUs.
	// Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
	promptValue, err := structpb.NewValue(map[string]interface{}{
		"prompt":     prompt,
		"parameters": parameters,
	})
	if err != nil {
		fmt.Fprintf(w, "unable to convert prompt to Value: %v", err)
		return err
	}

	req := &aiplatformpb.PredictRequest{
		Endpoint:  gemma2Endpoint,
		Instances: []*structpb.Value{promptValue},
	}

	resp, err := client.Predict(ctx, req)
	if err != nil {
		return err
	}

	prediction := resp.GetPredictions()
	value := prediction[0].GetStringValue()
	fmt.Fprintf(w, "%v", value)

	return nil
}