Esegui il deployment e l'inferenza di Gemma utilizzando Model Garden e gli endpoint basati su TPU di Vertex AI

In questo tutorial viene utilizzato Model Garden per eseguire il deployment del modello aperto Gemma 2B in un endpoint Vertex AI supportato da TPU. Devi eseguire il deployment di un modello in un endpoint prima di poterlo utilizzare per fornire previsioni online. Il deployment di un modello associa risorse fisiche al modello in modo che possa fornire previsioni online con bassa latenza.

Dopo aver eseguito il deployment del modello Gemma 2B, esegui l'inferenza del modello addestrato utilizzando PredictionServiceClient per ottenere previsioni online. Le previsioni online sono richieste sincrone effettuate a un modello di cui è stato eseguito il deployment in un endpoint.

Esegui il deployment di Gemma utilizzando Model Garden

Esegui il deployment del modello Gemma 2B su un tipo di macchina Compute Engine ct5lp-hightpu-1t ottimizzato per l'addestramento su piccola e media scala. Questa macchina ha un acceleratore TPU v5e. Per saperne di più sull'addestramento dei modelli utilizzando le TPU, consulta la sezione Addestramento di Cloud TPU v5e.

In questo tutorial, esegui il deployment del modello aperto Gemma 2B ottimizzato per le istruzioni utilizzando la scheda del modello in Model Garden. La versione specifica del modello è gemma2-2b-it. -it sta per ottimizzato per le istruzioni.

Il modello Gemma 2B ha dimensioni dei parametri inferiori, il che significa requisiti di risorse inferiori e maggiore flessibilità di implementazione.

  1. Nella console Google Cloud , vai alla pagina Model Garden.

    Vai a Model Garden

  2. Fai clic sulla scheda del modello Gemma 2.

    Vai a Gemma 2

  3. Fai clic su Esegui il deployment per aprire il riquadro Esegui il deployment del modello.

  4. Nel riquadro Esegui il deployment del modello, specifica questi dettagli.

    1. Per Deployment environment (Ambiente di deployment), fai clic su Vertex AI.

    2. Nella sezione Esegui il deployment del modello:

      1. In ID risorsa, scegli gemma-2b-it.

      2. Per Nome modello e Nome endpoint, accetta i valori predefiniti. Ad esempio:

        • Nome modello: gemma2-2b-it-1234567891234
        • Nome endpoint: gemma2-2b-it-mg-one-click-deploy

        Prendi nota del nome dell'endpoint. Ti servirà per trovare l'ID endpoint utilizzato negli esempi di codice.

    3. Nella sezione Impostazioni di deployment:

      1. Accetta l'opzione predefinita per le impostazioni di base.

      2. Per Regione, accetta il valore predefinito o scegli una regione dall'elenco. Prendi nota della regione. Ti servirà per gli esempi di codice.

      3. Per Specifica macchina, scegli l'istanza supportata dalla TPU: ct5lp-hightpu-1t (1 TPU_V5_LITEPOD; ct5lp-hightpu-1t).

  5. Fai clic su Esegui il deployment. Al termine del deployment, ricevi un'email contenente i dettagli del nuovo endpoint. Puoi anche visualizzare i dettagli dell'endpoint facendo clic su Previsione online > Endpoint e selezionando la tua regione.

    Vai a Endpoint

Inferenza di Gemma 2B con PredictionServiceClient

Dopo aver eseguito il deployment di Gemma 2B, utilizzi PredictionServiceClient per ottenere previsioni online per il prompt: "Perché il cielo è blu?"

Parametri di codice

Gli esempi di codice PredictionServiceClient richiedono di aggiornare quanto segue.

  • PROJECT_ID: Per trovare l'ID progetto, segui questi passaggi.

    1. Vai alla pagina Benvenuto nella console Google Cloud .

      Vai a Benvenuto

    2. Nel selettore di progetti nella parte superiore della pagina, seleziona il tuo progetto.

      Il nome, il numero e l'ID progetto vengono visualizzati dopo l'intestazione Benvenuto.

  • ENDPOINT_REGION: la regione in cui hai implementato l'endpoint.

  • ENDPOINT_ID: per trovare l'ID endpoint, visualizzalo nella console o esegui il comando gcloud ai endpoints list. Avrai bisogno del nome e della regione dell'endpoint dal riquadro Esegui il deployment del modello.

    Console

    Puoi visualizzare i dettagli dell'endpoint facendo clic su Online prediction > Endpoints e selezionando la tua regione. Prendi nota del numero visualizzato nella colonna ID.

    Vai a Endpoint

    gcloud

    Puoi visualizzare i dettagli dell'endpoint eseguendo il comando gcloud ai endpoints list.

    gcloud ai endpoints list \
      --region=ENDPOINT_REGION \
      --filter=display_name=ENDPOINT_NAME
    

    L'output è simile al seguente.

    Using endpoint [https://us-central1-aiplatform.googleapis.com/]
    ENDPOINT_ID: 1234567891234567891
    DISPLAY_NAME: gemma2-2b-it-mg-one-click-deploy
    

Codice di esempio

Nel codice campione per la tua lingua, aggiorna PROJECT_ID, ENDPOINT_REGION e ENDPOINT_ID. Quindi esegui il codice.

Python

Per scoprire come installare o aggiornare l'SDK Vertex AI Python, consulta Installare l'SDK Vertex AI Python. Per saperne di più, consulta la documentazione di riferimento dell'API Python.

"""
Sample to run inference on a Gemma2 model deployed to a Vertex AI endpoint with TPU accellerators.
"""

from google.cloud import aiplatform
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value

# TODO(developer): Update & uncomment lines below
# PROJECT_ID = "your-project-id"
# ENDPOINT_REGION = "your-vertex-endpoint-region"
# ENDPOINT_ID = "your-vertex-endpoint-id"

# Default configuration
config = {"max_tokens": 1024, "temperature": 0.9, "top_p": 1.0, "top_k": 1}

# Prompt used in the prediction
prompt = "Why is the sky blue?"

# Encapsulate the prompt in a correct format for TPUs
# Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
input = {"prompt": prompt}
input.update(config)

# Convert input message to a list of GAPIC instances for model input
instances = [json_format.ParseDict(input, Value())]

# Create a client
api_endpoint = f"{ENDPOINT_REGION}-aiplatform.googleapis.com"
client = aiplatform.gapic.PredictionServiceClient(
    client_options={"api_endpoint": api_endpoint}
)

# Call the Gemma2 endpoint
gemma2_end_point = (
    f"projects/{PROJECT_ID}/locations/{ENDPOINT_REGION}/endpoints/{ENDPOINT_ID}"
)
response = client.predict(
    endpoint=gemma2_end_point,
    instances=instances,
)
text_responses = response.predictions
print(text_responses[0])

Node.js

Prima di provare questo esempio, segui le istruzioni di configurazione di Node.js nella guida rapida di Vertex AI per l'utilizzo delle librerie client. Per saperne di più, consulta la documentazione di riferimento dell'API Vertex AI Node.js.

Per eseguire l'autenticazione in Vertex AI, configura le Credenziali predefinite dell'applicazione. Per ulteriori informazioni, consulta Configura l'autenticazione per un ambiente di sviluppo locale.

// Imports the Google Cloud Prediction Service Client library
const {
  // TODO(developer): Uncomment PredictionServiceClient before running the sample.
  // PredictionServiceClient,
  helpers,
} = require('@google-cloud/aiplatform');
/**
 * TODO(developer): Update these variables before running the sample.
 */
const projectId = 'your-project-id';
const endpointRegion = 'your-vertex-endpoint-region';
const endpointId = 'your-vertex-endpoint-id';

// Prompt used in the prediction
const prompt = 'Why is the sky blue?';

// Encapsulate the prompt in a correct format for TPUs
// Example format: [{prompt: 'Why is the sky blue?', temperature: 0.9}]
const input = {
  prompt,
  // Parameters for default configuration
  maxOutputTokens: 1024,
  temperature: 0.9,
  topP: 1.0,
  topK: 1,
};

// Convert input message to a list of GAPIC instances for model input
const instances = [helpers.toValue(input)];

// TODO(developer): Uncomment apiEndpoint and predictionServiceClient before running the sample.
// const apiEndpoint = `${endpointRegion}-aiplatform.googleapis.com`;

// Create a client
// predictionServiceClient = new PredictionServiceClient({apiEndpoint});

// Call the Gemma2 endpoint
const gemma2Endpoint = `projects/${projectId}/locations/${endpointRegion}/endpoints/${endpointId}`;

const [response] = await predictionServiceClient.predict({
  endpoint: gemma2Endpoint,
  instances,
});

const predictions = response.predictions;
const text = predictions[0].stringValue;

console.log('Predictions:', text);

Java

Prima di provare questo esempio, segui le istruzioni di configurazione di Java nella guida rapida di Vertex AI per l'utilizzo delle librerie client. Per saperne di più, consulta la documentazione di riferimento dell'API Vertex AI Java.

Per eseguire l'autenticazione in Vertex AI, configura le Credenziali predefinite dell'applicazione. Per ulteriori informazioni, consulta Configura l'autenticazione per un ambiente di sviluppo locale.


import com.google.cloud.aiplatform.v1.EndpointName;
import com.google.cloud.aiplatform.v1.PredictResponse;
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
import com.google.gson.Gson;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class Gemma2PredictTpu {
  private final PredictionServiceClient predictionServiceClient;

  // Constructor to inject the PredictionServiceClient
  public Gemma2PredictTpu(PredictionServiceClient predictionServiceClient) {
    this.predictionServiceClient = predictionServiceClient;
  }

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String projectId = "YOUR_PROJECT_ID";
    String endpointRegion = "us-west1";
    String endpointId = "YOUR_ENDPOINT_ID";

    PredictionServiceSettings predictionServiceSettings =
        PredictionServiceSettings.newBuilder()
            .setEndpoint(String.format("%s-aiplatform.googleapis.com:443", endpointRegion))
            .build();
    PredictionServiceClient predictionServiceClient =
        PredictionServiceClient.create(predictionServiceSettings);
    Gemma2PredictTpu creator = new Gemma2PredictTpu(predictionServiceClient);

    creator.gemma2PredictTpu(projectId, endpointRegion, endpointId);
  }

  // Demonstrates how to run inference on a Gemma2 model
  // deployed to a Vertex AI endpoint with TPU accelerators.
  public String gemma2PredictTpu(String projectId, String region,
           String endpointId) throws IOException {
    Map<String, Object> paramsMap = new HashMap<>();
    paramsMap.put("temperature", 0.9);
    paramsMap.put("maxOutputTokens", 1024);
    paramsMap.put("topP", 1.0);
    paramsMap.put("topK", 1);
    Value parameters = mapToValue(paramsMap);
    // Prompt used in the prediction
    String instance = "{ \"prompt\": \"Why is the sky blue?\"}";
    Value.Builder instanceValue = Value.newBuilder();
    JsonFormat.parser().merge(instance, instanceValue);
    // Encapsulate the prompt in a correct format for TPUs
    // Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
    List<Value> instances = new ArrayList<>();
    instances.add(instanceValue.build());

    EndpointName endpointName = EndpointName.of(projectId, region, endpointId);

    PredictResponse predictResponse = this.predictionServiceClient
        .predict(endpointName, instances, parameters);
    String textResponse = predictResponse.getPredictions(0).getStringValue();
    System.out.println(textResponse);
    return textResponse;
  }

  private static Value mapToValue(Map<String, Object> map) throws InvalidProtocolBufferException {
    Gson gson = new Gson();
    String json = gson.toJson(map);
    Value.Builder builder = Value.newBuilder();
    JsonFormat.parser().merge(json, builder);
    return builder.build();
  }
}

Go

Prima di provare questo esempio, segui le istruzioni di configurazione di Go nella guida rapida di Vertex AI per l'utilizzo delle librerie client. Per saperne di più, consulta la documentazione di riferimento dell'API Vertex AI Go.

Per eseguire l'autenticazione in Vertex AI, configura le Credenziali predefinite dell'applicazione. Per ulteriori informazioni, consulta Configura l'autenticazione per un ambiente di sviluppo locale.

import (
	"context"
	"fmt"
	"io"

	"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"

	"google.golang.org/protobuf/types/known/structpb"
)

// predictTPU demonstrates how to run interference on a Gemma2 model deployed to a Vertex AI endpoint with TPU accelerators.
func predictTPU(w io.Writer, client PredictionsClient, projectID, location, endpointID string) error {
	ctx := context.Background()

	// Note: client can be initialized in the following way:
	// apiEndpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", location)
	// client, err := aiplatform.NewPredictionClient(ctx, option.WithEndpoint(apiEndpoint))
	// if err != nil {
	// 	return fmt.Errorf("unable to create prediction client: %v", err)
	// }
	// defer client.Close()

	gemma2Endpoint := fmt.Sprintf("projects/%s/locations/%s/endpoints/%s", projectID, location, endpointID)
	prompt := "Why is the sky blue?"
	parameters := map[string]interface{}{
		"temperature":     0.9,
		"maxOutputTokens": 1024,
		"topP":            1.0,
		"topK":            1,
	}

	// Encapsulate the prompt in a correct format for TPUs.
	// Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
	promptValue, err := structpb.NewValue(map[string]interface{}{
		"prompt":     prompt,
		"parameters": parameters,
	})
	if err != nil {
		fmt.Fprintf(w, "unable to convert prompt to Value: %v", err)
		return err
	}

	req := &aiplatformpb.PredictRequest{
		Endpoint:  gemma2Endpoint,
		Instances: []*structpb.Value{promptValue},
	}

	resp, err := client.Predict(ctx, req)
	if err != nil {
		return err
	}

	prediction := resp.GetPredictions()
	value := prediction[0].GetStringValue()
	fmt.Fprintf(w, "%v", value)

	return nil
}