Implementa y realiza inferencias con Gemma usando Model Garden y extremos compatibles con TPU de Vertex AI

En este instructivo, usarás Model Garden para implementar el modelo abierto Gemma 2B en un extremo de Vertex AI respaldado por TPU. Debes implementar un modelo en un extremo antes de que se pueda usar para entregar predicciones en línea. La implementación de un modelo asocia recursos físicos con el modelo para que pueda entregar predicciones en línea con baja latencia.

Después de implementar el modelo de Gemma 2B, puedes realizar la inferencia del modelo entrenado con PredictionServiceClient para obtener predicciones en línea. Las predicciones en línea son solicitudes síncronas realizadas en un modelo que se implementa en un extremo.

Implementa Gemma con Model Garden

Implementarás el modelo Gemma 2B en un tipo de máquina ct5lp-hightpu-1t de Compute Engine optimizado para el entrenamiento a pequeña y mediana escala. Esta máquina tiene un acelerador de TPU v5e. Para obtener más información sobre el entrenamiento de modelos con TPU, consulta Entrenamiento de Cloud TPU v5e.

En este instructivo, implementarás el modelo abierto Gemma 2B ajustado para instrucciones con la tarjeta de modelo en Model Garden. La versión específica del modelo es gemma2-2b-it; -it significa ajustado según las instrucciones.

El modelo Gemma 2B tiene un tamaño de parámetro más bajo, lo que significa menores requisitos de recursos y más flexibilidad de implementación.

  1. En la consola de Google Cloud , ve a la página Model Garden.

    Ir a Model Garden

  2. Haz clic en la tarjeta de modelo Gemma 2.

    Ir a Gemma 2

  3. Haz clic en Implementar para abrir el panel Implementar modelo.

  4. En el panel Deploy model, especifica estos detalles.

    1. En Entorno de implementación, haz clic en Vertex AI.

    2. En la sección Implementar modelo, haz lo siguiente:

      1. En ID del recurso, elige gemma-2b-it.

      2. En Nombre del modelo y Nombre del extremo, acepta los valores predeterminados. Por ejemplo:

        • Nombre del modelo: gemma2-2b-it-1234567891234
        • Nombre del extremo: gemma2-2b-it-mg-one-click-deploy

        Toma nota del nombre del extremo. Lo necesitarás para encontrar el ID del extremo que se usa en las muestras de código.

    3. En la sección Configuración de la implementación, haz lo siguiente:

      1. Acepta la opción predeterminada para la configuración Básica.

      2. En Región, acepta el valor predeterminado o elige una región de la lista. Anota la región. La necesitarás para los ejemplos de código.

      3. En Especificación de la máquina, elige la instancia respaldada por TPU: ct5lp-hightpu-1t (1 TPU_V5_LITEPOD; ct5lp-hightpu-1t).

  5. Haz clic en Implementar. Cuando finalice la implementación, recibirás un correo electrónico con detalles sobre tu nuevo extremo. También puedes ver los detalles del extremo haciendo clic en Predicción en línea > Endpoints y seleccionando tu región.

    Ir a Endpoints

Inferencia de Gemma 2B con PredictionServiceClient

Después de implementar Gemma 2B, usas PredictionServiceClient para obtener predicciones en línea para la instrucción: "¿Por qué el cielo es azul?".

Parámetros de código

En las muestras de código de PredictionServiceClient, debes actualizar lo siguiente.

  • PROJECT_ID: Para encontrar el ID de tu proyecto, sigue estos pasos.

    1. Ve a la página Bienvenida en la consola de Google Cloud .

      Ir a Bienvenida

    2. En el selector de proyectos que se encuentra en la parte superior de la página, selecciona tu proyecto.

      El nombre, el número y el ID del proyecto aparecen después del encabezado Bienvenido.

  • ENDPOINT_REGION: Es la región en la que implementaste el extremo.

  • ENDPOINT_ID: Para encontrar el ID de tu extremo, míralo en la consola o ejecuta el comando gcloud ai endpoints list. Necesitarás el nombre y la región del extremo del panel Implementar modelo.

    Console

    Para ver los detalles del extremo, haz clic en Predicción en línea > Endpoints y selecciona tu región. Toma nota del número que aparece en la columna ID.

    Ir a Endpoints

    gcloud

    Puedes ver los detalles del extremo ejecutando el comando gcloud ai endpoints list.

    gcloud ai endpoints list \
      --region=ENDPOINT_REGION \
      --filter=display_name=ENDPOINT_NAME
    

    El resultado se verá así:

    Using endpoint [https://us-central1-aiplatform.googleapis.com/]
    ENDPOINT_ID: 1234567891234567891
    DISPLAY_NAME: gemma2-2b-it-mg-one-click-deploy
    

Código de muestra

En el código de muestra de tu lenguaje, actualiza PROJECT_ID, ENDPOINT_REGION y ENDPOINT_ID. Luego, ejecuta tu código.

Python

Si deseas obtener información para instalar o actualizar el SDK de Vertex AI para Python, consulta Instala el SDK de Vertex AI para Python. Para obtener más información, consulta la documentación de referencia de la API de Python.

"""
Sample to run inference on a Gemma2 model deployed to a Vertex AI endpoint with TPU accellerators.
"""

from google.cloud import aiplatform
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value

# TODO(developer): Update & uncomment lines below
# PROJECT_ID = "your-project-id"
# ENDPOINT_REGION = "your-vertex-endpoint-region"
# ENDPOINT_ID = "your-vertex-endpoint-id"

# Default configuration
config = {"max_tokens": 1024, "temperature": 0.9, "top_p": 1.0, "top_k": 1}

# Prompt used in the prediction
prompt = "Why is the sky blue?"

# Encapsulate the prompt in a correct format for TPUs
# Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
input = {"prompt": prompt}
input.update(config)

# Convert input message to a list of GAPIC instances for model input
instances = [json_format.ParseDict(input, Value())]

# Create a client
api_endpoint = f"{ENDPOINT_REGION}-aiplatform.googleapis.com"
client = aiplatform.gapic.PredictionServiceClient(
    client_options={"api_endpoint": api_endpoint}
)

# Call the Gemma2 endpoint
gemma2_end_point = (
    f"projects/{PROJECT_ID}/locations/{ENDPOINT_REGION}/endpoints/{ENDPOINT_ID}"
)
response = client.predict(
    endpoint=gemma2_end_point,
    instances=instances,
)
text_responses = response.predictions
print(text_responses[0])

Node.js

Antes de probar este ejemplo, sigue las instrucciones de configuración para Node.js incluidas en la guía de inicio rápido de Vertex AI sobre cómo usar bibliotecas cliente. Para obtener más información, consulta la documentación de referencia de la API de Vertex AI Node.js.

Para autenticarte en Vertex AI, configura las credenciales predeterminadas de la aplicación. Si deseas obtener más información, consulta Configura la autenticación para un entorno de desarrollo local.

// Imports the Google Cloud Prediction Service Client library
const {
  // TODO(developer): Uncomment PredictionServiceClient before running the sample.
  // PredictionServiceClient,
  helpers,
} = require('@google-cloud/aiplatform');
/**
 * TODO(developer): Update these variables before running the sample.
 */
const projectId = 'your-project-id';
const endpointRegion = 'your-vertex-endpoint-region';
const endpointId = 'your-vertex-endpoint-id';

// Prompt used in the prediction
const prompt = 'Why is the sky blue?';

// Encapsulate the prompt in a correct format for TPUs
// Example format: [{prompt: 'Why is the sky blue?', temperature: 0.9}]
const input = {
  prompt,
  // Parameters for default configuration
  maxOutputTokens: 1024,
  temperature: 0.9,
  topP: 1.0,
  topK: 1,
};

// Convert input message to a list of GAPIC instances for model input
const instances = [helpers.toValue(input)];

// TODO(developer): Uncomment apiEndpoint and predictionServiceClient before running the sample.
// const apiEndpoint = `${endpointRegion}-aiplatform.googleapis.com`;

// Create a client
// predictionServiceClient = new PredictionServiceClient({apiEndpoint});

// Call the Gemma2 endpoint
const gemma2Endpoint = `projects/${projectId}/locations/${endpointRegion}/endpoints/${endpointId}`;

const [response] = await predictionServiceClient.predict({
  endpoint: gemma2Endpoint,
  instances,
});

const predictions = response.predictions;
const text = predictions[0].stringValue;

console.log('Predictions:', text);

Java

Antes de probar este ejemplo, sigue las instrucciones de configuración para Java incluidas en la guía de inicio rápido de Vertex AI sobre cómo usar bibliotecas cliente. Para obtener más información, consulta la documentación de referencia de la API de Vertex AI Java.

Para autenticarte en Vertex AI, configura las credenciales predeterminadas de la aplicación. Si deseas obtener más información, consulta Configura la autenticación para un entorno de desarrollo local.


import com.google.cloud.aiplatform.v1.EndpointName;
import com.google.cloud.aiplatform.v1.PredictResponse;
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
import com.google.gson.Gson;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class Gemma2PredictTpu {
  private final PredictionServiceClient predictionServiceClient;

  // Constructor to inject the PredictionServiceClient
  public Gemma2PredictTpu(PredictionServiceClient predictionServiceClient) {
    this.predictionServiceClient = predictionServiceClient;
  }

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String projectId = "YOUR_PROJECT_ID";
    String endpointRegion = "us-west1";
    String endpointId = "YOUR_ENDPOINT_ID";

    PredictionServiceSettings predictionServiceSettings =
        PredictionServiceSettings.newBuilder()
            .setEndpoint(String.format("%s-aiplatform.googleapis.com:443", endpointRegion))
            .build();
    PredictionServiceClient predictionServiceClient =
        PredictionServiceClient.create(predictionServiceSettings);
    Gemma2PredictTpu creator = new Gemma2PredictTpu(predictionServiceClient);

    creator.gemma2PredictTpu(projectId, endpointRegion, endpointId);
  }

  // Demonstrates how to run inference on a Gemma2 model
  // deployed to a Vertex AI endpoint with TPU accelerators.
  public String gemma2PredictTpu(String projectId, String region,
           String endpointId) throws IOException {
    Map<String, Object> paramsMap = new HashMap<>();
    paramsMap.put("temperature", 0.9);
    paramsMap.put("maxOutputTokens", 1024);
    paramsMap.put("topP", 1.0);
    paramsMap.put("topK", 1);
    Value parameters = mapToValue(paramsMap);
    // Prompt used in the prediction
    String instance = "{ \"prompt\": \"Why is the sky blue?\"}";
    Value.Builder instanceValue = Value.newBuilder();
    JsonFormat.parser().merge(instance, instanceValue);
    // Encapsulate the prompt in a correct format for TPUs
    // Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
    List<Value> instances = new ArrayList<>();
    instances.add(instanceValue.build());

    EndpointName endpointName = EndpointName.of(projectId, region, endpointId);

    PredictResponse predictResponse = this.predictionServiceClient
        .predict(endpointName, instances, parameters);
    String textResponse = predictResponse.getPredictions(0).getStringValue();
    System.out.println(textResponse);
    return textResponse;
  }

  private static Value mapToValue(Map<String, Object> map) throws InvalidProtocolBufferException {
    Gson gson = new Gson();
    String json = gson.toJson(map);
    Value.Builder builder = Value.newBuilder();
    JsonFormat.parser().merge(json, builder);
    return builder.build();
  }
}

Go

Antes de probar este ejemplo, sigue las instrucciones de configuración para Go incluidas en la guía de inicio rápido de Vertex AI sobre cómo usar bibliotecas cliente. Para obtener más información, consulta la documentación de referencia de la API de Vertex AI Go.

Para autenticarte en Vertex AI, configura las credenciales predeterminadas de la aplicación. Si deseas obtener más información, consulta Configura la autenticación para un entorno de desarrollo local.

import (
	"context"
	"fmt"
	"io"

	"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"

	"google.golang.org/protobuf/types/known/structpb"
)

// predictTPU demonstrates how to run interference on a Gemma2 model deployed to a Vertex AI endpoint with TPU accelerators.
func predictTPU(w io.Writer, client PredictionsClient, projectID, location, endpointID string) error {
	ctx := context.Background()

	// Note: client can be initialized in the following way:
	// apiEndpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", location)
	// client, err := aiplatform.NewPredictionClient(ctx, option.WithEndpoint(apiEndpoint))
	// if err != nil {
	// 	return fmt.Errorf("unable to create prediction client: %v", err)
	// }
	// defer client.Close()

	gemma2Endpoint := fmt.Sprintf("projects/%s/locations/%s/endpoints/%s", projectID, location, endpointID)
	prompt := "Why is the sky blue?"
	parameters := map[string]interface{}{
		"temperature":     0.9,
		"maxOutputTokens": 1024,
		"topP":            1.0,
		"topK":            1,
	}

	// Encapsulate the prompt in a correct format for TPUs.
	// Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
	promptValue, err := structpb.NewValue(map[string]interface{}{
		"prompt":     prompt,
		"parameters": parameters,
	})
	if err != nil {
		fmt.Fprintf(w, "unable to convert prompt to Value: %v", err)
		return err
	}

	req := &aiplatformpb.PredictRequest{
		Endpoint:  gemma2Endpoint,
		Instances: []*structpb.Value{promptValue},
	}

	resp, err := client.Predict(ctx, req)
	if err != nil {
		return err
	}

	prediction := resp.GetPredictions()
	value := prediction[0].GetStringValue()
	fmt.Fprintf(w, "%v", value)

	return nil
}