Gemma mit Model Garden und Vertex AI-Endpunkten mit TPU bereitstellen und Inferenzen ausführen

In diesem Tutorial stellen Sie das offene Modell Gemma 2B mithilfe von Model Garden auf einem TPU-gestützten Vertex AI-Endpunkt bereit. Sie müssen ein Modell auf einem Endpunkt bereitstellen, bevor es für Onlinevorhersagen verwendet werden kann. Durch die Bereitstellung eines Modells werden dem Modell physische Ressourcen zugeordnet, sodass es Onlinevorhersagen mit niedriger Latenz bereitstellen kann.

Nachdem Sie das Gemma 2B-Modell bereitgestellt haben, führen Sie die Inferenz für das trainierte Modell mit PredictionServiceClient aus, um Onlinevorhersagen zu erhalten. Onlinevorhersagen sind synchrone Anfragen an ein Modell, das auf einem Endpunkt bereitgestellt wird.

Gemma mit Model Garden bereitstellen

Sie stellen das Gemma 2B-Modell auf einem ct5lp-hightpu-1t-Compute Engine-Maschinentyp bereit, der für das Training im kleinen bis mittleren Maßstab optimiert ist. Diese Maschine hat einen TPU v5e-Beschleuniger. Weitere Informationen zum Trainieren von Modellen mit TPUs finden Sie unter Cloud TPU v5e-Training.

In dieser Anleitung stellen Sie das auf Anweisungen abgestimmte offene Modell Gemma 2B mithilfe der Modellkarte in Model Garden bereit. Die spezifische Modellversion ist gemma2-2b-it – -it steht für auf Anweisungen abgestimmt.

Das Gemma 2B-Modell hat eine geringere Parametergröße, was zu geringeren Ressourcenanforderungen und mehr Flexibilität bei der Bereitstellung führt.

  1. Rufen Sie in der Google Cloud Console die Seite Model Garden auf.

    Zu Model Garden

  2. Klicken Sie auf die Modellkarte Gemma 2.

    Zu Gemma 2

  3. Klicken Sie auf Bereitstellen, um den Bereich Modell bereitstellen zu öffnen.

  4. Geben Sie im Bereich Modell bereitstellen die folgenden Details an.

    1. Klicken Sie für Bereitstellungsumgebung auf Vertex AI.

    2. Im Abschnitt Modell bereitstellen:

      1. Wählen Sie für Ressourcen-ID die Option gemma-2b-it aus.

      2. Übernehmen Sie für Modellname und Endpunktname die Standardwerte. Beispiel:

        • Modellname: gemma2-2b-it-1234567891234
        • Endpunktname: gemma2-2b-it-mg-one-click-deploy

        Notieren Sie sich den Endpunktnamen. Sie benötigen sie, um die in den Codebeispielen verwendete Endpunkt-ID zu finden.

    3. Im Abschnitt Bereitstellungseinstellungen:

      1. Übernehmen Sie die Standardoption für die Grundeinstellungen.

      2. Übernehmen Sie für Region den Standardwert oder wählen Sie eine Region aus der Liste aus. Notieren Sie sich die Region. Sie benötigen sie für die Codebeispiele.

      3. Wählen Sie für Maschinenspezifikation die TPU-basierte Instanz aus: ct5lp-hightpu-1t (1 TPU_V5_LITEPOD; ct5lp-hightpu-1t).

  5. Klicken Sie auf Bereitstellen. Nach Abschluss des Deployments erhalten Sie eine E-Mail mit Details zu Ihrem neuen Endpunkt. Sie können die Endpunktdetails auch aufrufen, indem Sie auf Onlinevorhersage > Endpunkte klicken und Ihre Region auswählen.

    Endpunkte aufrufen

Inferenz mit Gemma 2B mit dem PredictionServiceClient

Nachdem Sie Gemma 2B bereitgestellt haben, verwenden Sie die PredictionServiceClient, um Onlinevorhersagen für den Prompt „Warum ist der Himmel blau?“ zu erhalten.

Codeparameter

Für die PredictionServiceClient-Codebeispiele müssen Sie Folgendes aktualisieren.

  • PROJECT_ID: So finden Sie Ihre Projekt-ID.

    1. Rufen Sie in der Google Cloud Console die Seite Willkommen auf.

      Zur Begrüßungsseite

    2. Wählen Sie oben auf der Seite in der Projektauswahl Ihr Projekt aus.

      Der Projektname, die Projektnummer und die Projekt-ID werden nach der Überschrift Willkommen angezeigt.

  • ENDPOINT_REGION: Die Region, in der Sie den Endpunkt bereitgestellt haben.

  • ENDPOINT_ID: Die Endpunkt-ID finden Sie in der Konsole oder indem Sie den Befehl gcloud ai endpoints list ausführen. Sie benötigen den Endpunktnamen und die Region aus dem Bereich Modell bereitstellen.

    Console

    Sie können die Endpunktdetails aufrufen, indem Sie auf Onlinevorhersage > Endpunkte klicken und Ihre Region auswählen. Notieren Sie sich die Zahl, die in der Spalte ID angezeigt wird.

    Endpunkte aufrufen

    gcloud

    Mit dem Befehl gcloud ai endpoints list können Sie die Endpunktdetails aufrufen.

    gcloud ai endpoints list \
      --region=ENDPOINT_REGION \
      --filter=display_name=ENDPOINT_NAME
    

    Die Ausgabe sieht so aus:

    Using endpoint [https://us-central1-aiplatform.googleapis.com/]
    ENDPOINT_ID: 1234567891234567891
    DISPLAY_NAME: gemma2-2b-it-mg-one-click-deploy
    

Beispielcode

Aktualisieren Sie im Beispielcode für Ihre Sprache PROJECT_ID, ENDPOINT_REGION und ENDPOINT_ID. Führen Sie dann Ihren Code aus.

Python

Informationen zur Installation des Vertex AI SDK for Python finden Sie unter Vertex AI SDK for Python installieren. Weitere Informationen finden Sie in der Python-API-Referenzdokumentation.

"""
Sample to run inference on a Gemma2 model deployed to a Vertex AI endpoint with TPU accellerators.
"""

from google.cloud import aiplatform
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value

# TODO(developer): Update & uncomment lines below
# PROJECT_ID = "your-project-id"
# ENDPOINT_REGION = "your-vertex-endpoint-region"
# ENDPOINT_ID = "your-vertex-endpoint-id"

# Default configuration
config = {"max_tokens": 1024, "temperature": 0.9, "top_p": 1.0, "top_k": 1}

# Prompt used in the prediction
prompt = "Why is the sky blue?"

# Encapsulate the prompt in a correct format for TPUs
# Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
input = {"prompt": prompt}
input.update(config)

# Convert input message to a list of GAPIC instances for model input
instances = [json_format.ParseDict(input, Value())]

# Create a client
api_endpoint = f"{ENDPOINT_REGION}-aiplatform.googleapis.com"
client = aiplatform.gapic.PredictionServiceClient(
    client_options={"api_endpoint": api_endpoint}
)

# Call the Gemma2 endpoint
gemma2_end_point = (
    f"projects/{PROJECT_ID}/locations/{ENDPOINT_REGION}/endpoints/{ENDPOINT_ID}"
)
response = client.predict(
    endpoint=gemma2_end_point,
    instances=instances,
)
text_responses = response.predictions
print(text_responses[0])

Node.js

Bevor Sie dieses Beispiel anwenden, folgen Sie den Node.js-Einrichtungsschritten in der Vertex AI-Kurzanleitung zur Verwendung von Clientbibliotheken. Weitere Informationen finden Sie in der Referenzdokumentation zur Vertex AI Node.js API.

Richten Sie zur Authentifizierung bei Vertex AI Standardanmeldedaten für Anwendungen ein. Weitere Informationen finden Sie unter Authentifizierung für eine lokale Entwicklungsumgebung einrichten.

// Imports the Google Cloud Prediction Service Client library
const {
  // TODO(developer): Uncomment PredictionServiceClient before running the sample.
  // PredictionServiceClient,
  helpers,
} = require('@google-cloud/aiplatform');
/**
 * TODO(developer): Update these variables before running the sample.
 */
const projectId = 'your-project-id';
const endpointRegion = 'your-vertex-endpoint-region';
const endpointId = 'your-vertex-endpoint-id';

// Prompt used in the prediction
const prompt = 'Why is the sky blue?';

// Encapsulate the prompt in a correct format for TPUs
// Example format: [{prompt: 'Why is the sky blue?', temperature: 0.9}]
const input = {
  prompt,
  // Parameters for default configuration
  maxOutputTokens: 1024,
  temperature: 0.9,
  topP: 1.0,
  topK: 1,
};

// Convert input message to a list of GAPIC instances for model input
const instances = [helpers.toValue(input)];

// TODO(developer): Uncomment apiEndpoint and predictionServiceClient before running the sample.
// const apiEndpoint = `${endpointRegion}-aiplatform.googleapis.com`;

// Create a client
// predictionServiceClient = new PredictionServiceClient({apiEndpoint});

// Call the Gemma2 endpoint
const gemma2Endpoint = `projects/${projectId}/locations/${endpointRegion}/endpoints/${endpointId}`;

const [response] = await predictionServiceClient.predict({
  endpoint: gemma2Endpoint,
  instances,
});

const predictions = response.predictions;
const text = predictions[0].stringValue;

console.log('Predictions:', text);

Java

Bevor Sie dieses Beispiel anwenden, folgen Sie den Java-Einrichtungsschritten in der Vertex AI-Kurzanleitung zur Verwendung von Clientbibliotheken. Weitere Informationen finden Sie in der Referenzdokumentation zur Vertex AI Java API.

Richten Sie zur Authentifizierung bei Vertex AI Standardanmeldedaten für Anwendungen ein. Weitere Informationen finden Sie unter Authentifizierung für eine lokale Entwicklungsumgebung einrichten.


import com.google.cloud.aiplatform.v1.EndpointName;
import com.google.cloud.aiplatform.v1.PredictResponse;
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
import com.google.gson.Gson;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class Gemma2PredictTpu {
  private final PredictionServiceClient predictionServiceClient;

  // Constructor to inject the PredictionServiceClient
  public Gemma2PredictTpu(PredictionServiceClient predictionServiceClient) {
    this.predictionServiceClient = predictionServiceClient;
  }

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String projectId = "YOUR_PROJECT_ID";
    String endpointRegion = "us-west1";
    String endpointId = "YOUR_ENDPOINT_ID";

    PredictionServiceSettings predictionServiceSettings =
        PredictionServiceSettings.newBuilder()
            .setEndpoint(String.format("%s-aiplatform.googleapis.com:443", endpointRegion))
            .build();
    PredictionServiceClient predictionServiceClient =
        PredictionServiceClient.create(predictionServiceSettings);
    Gemma2PredictTpu creator = new Gemma2PredictTpu(predictionServiceClient);

    creator.gemma2PredictTpu(projectId, endpointRegion, endpointId);
  }

  // Demonstrates how to run inference on a Gemma2 model
  // deployed to a Vertex AI endpoint with TPU accelerators.
  public String gemma2PredictTpu(String projectId, String region,
           String endpointId) throws IOException {
    Map<String, Object> paramsMap = new HashMap<>();
    paramsMap.put("temperature", 0.9);
    paramsMap.put("maxOutputTokens", 1024);
    paramsMap.put("topP", 1.0);
    paramsMap.put("topK", 1);
    Value parameters = mapToValue(paramsMap);
    // Prompt used in the prediction
    String instance = "{ \"prompt\": \"Why is the sky blue?\"}";
    Value.Builder instanceValue = Value.newBuilder();
    JsonFormat.parser().merge(instance, instanceValue);
    // Encapsulate the prompt in a correct format for TPUs
    // Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
    List<Value> instances = new ArrayList<>();
    instances.add(instanceValue.build());

    EndpointName endpointName = EndpointName.of(projectId, region, endpointId);

    PredictResponse predictResponse = this.predictionServiceClient
        .predict(endpointName, instances, parameters);
    String textResponse = predictResponse.getPredictions(0).getStringValue();
    System.out.println(textResponse);
    return textResponse;
  }

  private static Value mapToValue(Map<String, Object> map) throws InvalidProtocolBufferException {
    Gson gson = new Gson();
    String json = gson.toJson(map);
    Value.Builder builder = Value.newBuilder();
    JsonFormat.parser().merge(json, builder);
    return builder.build();
  }
}

Go

Bevor Sie dieses Beispiel anwenden, folgen Sie den Go-Einrichtungsschritten in der Vertex AI-Kurzanleitung zur Verwendung von Clientbibliotheken. Weitere Informationen finden Sie in der Referenzdokumentation zur Vertex AI Go API.

Richten Sie zur Authentifizierung bei Vertex AI Standardanmeldedaten für Anwendungen ein. Weitere Informationen finden Sie unter Authentifizierung für eine lokale Entwicklungsumgebung einrichten.

import (
	"context"
	"fmt"
	"io"

	"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"

	"google.golang.org/protobuf/types/known/structpb"
)

// predictTPU demonstrates how to run interference on a Gemma2 model deployed to a Vertex AI endpoint with TPU accelerators.
func predictTPU(w io.Writer, client PredictionsClient, projectID, location, endpointID string) error {
	ctx := context.Background()

	// Note: client can be initialized in the following way:
	// apiEndpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", location)
	// client, err := aiplatform.NewPredictionClient(ctx, option.WithEndpoint(apiEndpoint))
	// if err != nil {
	// 	return fmt.Errorf("unable to create prediction client: %v", err)
	// }
	// defer client.Close()

	gemma2Endpoint := fmt.Sprintf("projects/%s/locations/%s/endpoints/%s", projectID, location, endpointID)
	prompt := "Why is the sky blue?"
	parameters := map[string]interface{}{
		"temperature":     0.9,
		"maxOutputTokens": 1024,
		"topP":            1.0,
		"topK":            1,
	}

	// Encapsulate the prompt in a correct format for TPUs.
	// Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
	promptValue, err := structpb.NewValue(map[string]interface{}{
		"prompt":     prompt,
		"parameters": parameters,
	})
	if err != nil {
		fmt.Fprintf(w, "unable to convert prompt to Value: %v", err)
		return err
	}

	req := &aiplatformpb.PredictRequest{
		Endpoint:  gemma2Endpoint,
		Instances: []*structpb.Value{promptValue},
	}

	resp, err := client.Predict(ctx, req)
	if err != nil {
		return err
	}

	prediction := resp.GetPredictions()
	value := prediction[0].GetStringValue()
	fmt.Fprintf(w, "%v", value)

	return nil
}