Déployer et inférer Gemma à l'aide de Model Garden et de points de terminaison Vertex AI basés sur des TPU

Dans ce tutoriel, vous utilisez Model Garden pour déployer le modèle ouvert Gemma 2B sur un point de terminaison Vertex AI basé sur TPU. Vous devez déployer un modèle sur un point de terminaison avant de pouvoir utiliser ce modèle pour livrer des prédictions en ligne. Le déploiement d'un modèle associe des ressources physiques au modèle afin qu'il puisse générer des prédictions en ligne avec une faible latence.

Après avoir déployé le modèle Gemma 2B, vous pouvez exécuter l'inférence du modèle entraîné à l'aide de PredictionServiceClient pour obtenir des prédictions en ligne. Les prédictions en ligne sont le résultat de requêtes synchrones adressées à un modèle déployé sur un point de terminaison.

Déployer Gemma à l'aide de Model Garden

Vous déployez le modèle Gemma 2B sur un type de machine ct5lp-hightpu-1t Compute Engine optimisé pour l'entraînement à petite ou moyenne échelle. Cette machine dispose d'un accélérateur TPU v5e. Pour en savoir plus sur l'entraînement de modèles à l'aide de TPU, consultez Entraînement avec Cloud TPU v5e.

Dans ce tutoriel, vous allez déployer le modèle ouvert Gemma 2B adapté aux instructions à l'aide de la fiche de modèle dans Model Garden. La version spécifique du modèle est gemma2-2b-it. -it signifie qu'il est adapté aux instructions.

Le modèle Gemma 2B présente une taille de paramètre plus basse, ce qui signifie des besoins en ressources moins élevés et plus de flexibilité de déploiement.

  1. Dans la console Google Cloud , accédez à la page Model Garden.

    Accéder à Model Garden

  2. Cliquez sur la fiche du modèle Gemma 2.

    Accéder à Gemma 2

  3. Cliquez sur Déployer pour ouvrir le volet Déployer le modèle.

  4. Dans le volet Déployer le modèle, spécifiez les informations suivantes.

    1. Dans le champ Environnement de déploiement, cliquez sur Vertex AI.

    2. Dans la section Déployer le modèle :

      1. Dans le champ ID de ressource, sélectionnez gemma-2b-it.

      2. Pour Nom du modèle et Nom du point de terminaison, acceptez les valeurs par défaut. Par exemple :

        • Nom du modèle : gemma2-2b-it-1234567891234
        • Nom du point de terminaison : gemma2-2b-it-mg-one-click-deploy

        Notez le nom du point de terminaison. Vous en aurez besoin pour trouver l'ID de point de terminaison utilisé dans les exemples de code.

    3. Dans la section Paramètres de déploiement :

      1. Acceptez l'option par défaut pour les paramètres de base.

      2. Pour Région, acceptez la valeur par défaut ou sélectionnez une région dans la liste. Notez la région. Vous en aurez besoin pour les exemples de code.

      3. Dans Spécifications de la machine, sélectionnez l'instance basée sur TPU : ct5lp-hightpu-1t (1 TPU_V5_LITEPOD; ct5lp-hightpu-1t).

  5. Cliquez sur Déployer. Une fois le déploiement terminé, vous recevez un e-mail contenant des informations sur votre nouveau point de terminaison. Vous pouvez également afficher les détails du point de terminaison en cliquant sur Prédiction en ligne > Points de terminaison et en sélectionnant votre région.

    Accéder à la page "Points de terminaison"

Effectuer une inférence du modèle Gemma 2B avec PredictionServiceClient

Après avoir déployé le modèle Gemma 2B, vous pouvez utiliser PredictionServiceClient pour obtenir des prédictions en ligne pour le prompt "Pourquoi le ciel est-il bleu ?"

Paramètres de code

Les exemples de code PredictionServiceClient nécessitent que vous mettiez à jour les éléments suivants.

  • PROJECT_ID : pour trouver l'ID de votre projet, procédez comme suit.

    1. Accédez à la page d'accueil de la console Google Cloud .

      Accéder à la page d'accueil

    2. Dans le sélecteur de projets situé en haut de la page, sélectionnez votre projet.

      Le nom ainsi que le numéro et l'ID du projet apparaissent après l'en-tête Bienvenue.

  • ENDPOINT_REGION : région dans laquelle vous avez déployé le point de terminaison.

  • ENDPOINT_ID : pour trouver l'ID de votre point de terminaison, affichez-le dans la console ou exécutez la commande gcloud ai endpoints list. Vous aurez besoin du nom et de la région du point de terminaison dans le volet Déployer le modèle.

    Console

    Vous pouvez afficher les détails du point de terminaison en cliquant sur Prédiction en ligne > Points de terminaison et en sélectionnant votre région. Notez le nombre qui s'affiche dans la colonne ID.

    Accéder à la page "Points de terminaison"

    gcloud

    Vous pouvez afficher les détails du point de terminaison en exécutant la commande gcloud ai endpoints list.

    gcloud ai endpoints list \
      --region=ENDPOINT_REGION \
      --filter=display_name=ENDPOINT_NAME
    

    Le résultat ressemble à ceci :

    Using endpoint [https://us-central1-aiplatform.googleapis.com/]
    ENDPOINT_ID: 1234567891234567891
    DISPLAY_NAME: gemma2-2b-it-mg-one-click-deploy
    

Exemple de code

Dans l'exemple de code correspondant à votre langage, modifiez PROJECT_ID, ENDPOINT_REGION et ENDPOINT_ID. Exécutez ensuite votre code.

Python

Pour savoir comment installer ou mettre à jour le SDK Vertex AI pour Python, consultez Installer le SDK Vertex AI pour Python. Pour en savoir plus, consultez la documentation de référence de l'API Python.

"""
Sample to run inference on a Gemma2 model deployed to a Vertex AI endpoint with TPU accellerators.
"""

from google.cloud import aiplatform
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value

# TODO(developer): Update & uncomment lines below
# PROJECT_ID = "your-project-id"
# ENDPOINT_REGION = "your-vertex-endpoint-region"
# ENDPOINT_ID = "your-vertex-endpoint-id"

# Default configuration
config = {"max_tokens": 1024, "temperature": 0.9, "top_p": 1.0, "top_k": 1}

# Prompt used in the prediction
prompt = "Why is the sky blue?"

# Encapsulate the prompt in a correct format for TPUs
# Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
input = {"prompt": prompt}
input.update(config)

# Convert input message to a list of GAPIC instances for model input
instances = [json_format.ParseDict(input, Value())]

# Create a client
api_endpoint = f"{ENDPOINT_REGION}-aiplatform.googleapis.com"
client = aiplatform.gapic.PredictionServiceClient(
    client_options={"api_endpoint": api_endpoint}
)

# Call the Gemma2 endpoint
gemma2_end_point = (
    f"projects/{PROJECT_ID}/locations/{ENDPOINT_REGION}/endpoints/{ENDPOINT_ID}"
)
response = client.predict(
    endpoint=gemma2_end_point,
    instances=instances,
)
text_responses = response.predictions
print(text_responses[0])

Node.js

Avant d'essayer cet exemple, suivez les instructions de configuration pour Node.js décrites dans le guide de démarrage rapide de Vertex AI sur l'utilisation des bibliothèques clientes. Pour en savoir plus, consultez la documentation de référence de l'API Vertex AI pour Node.js.

Pour vous authentifier auprès de Vertex AI, configurez le service Identifiants par défaut de l'application. Pour en savoir plus, consultez Configurer l'authentification pour un environnement de développement local.

// Imports the Google Cloud Prediction Service Client library
const {
  // TODO(developer): Uncomment PredictionServiceClient before running the sample.
  // PredictionServiceClient,
  helpers,
} = require('@google-cloud/aiplatform');
/**
 * TODO(developer): Update these variables before running the sample.
 */
const projectId = 'your-project-id';
const endpointRegion = 'your-vertex-endpoint-region';
const endpointId = 'your-vertex-endpoint-id';

// Prompt used in the prediction
const prompt = 'Why is the sky blue?';

// Encapsulate the prompt in a correct format for TPUs
// Example format: [{prompt: 'Why is the sky blue?', temperature: 0.9}]
const input = {
  prompt,
  // Parameters for default configuration
  maxOutputTokens: 1024,
  temperature: 0.9,
  topP: 1.0,
  topK: 1,
};

// Convert input message to a list of GAPIC instances for model input
const instances = [helpers.toValue(input)];

// TODO(developer): Uncomment apiEndpoint and predictionServiceClient before running the sample.
// const apiEndpoint = `${endpointRegion}-aiplatform.googleapis.com`;

// Create a client
// predictionServiceClient = new PredictionServiceClient({apiEndpoint});

// Call the Gemma2 endpoint
const gemma2Endpoint = `projects/${projectId}/locations/${endpointRegion}/endpoints/${endpointId}`;

const [response] = await predictionServiceClient.predict({
  endpoint: gemma2Endpoint,
  instances,
});

const predictions = response.predictions;
const text = predictions[0].stringValue;

console.log('Predictions:', text);

Java

Avant d'essayer cet exemple, suivez les instructions de configuration pour Java décrites dans le guide de démarrage rapide de Vertex AI sur l'utilisation des bibliothèques clientes. Pour en savoir plus, consultez la documentation de référence de l'API Vertex AI pour Java.

Pour vous authentifier auprès de Vertex AI, configurez le service Identifiants par défaut de l'application. Pour en savoir plus, consultez Configurer l'authentification pour un environnement de développement local.


import com.google.cloud.aiplatform.v1.EndpointName;
import com.google.cloud.aiplatform.v1.PredictResponse;
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
import com.google.gson.Gson;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class Gemma2PredictTpu {
  private final PredictionServiceClient predictionServiceClient;

  // Constructor to inject the PredictionServiceClient
  public Gemma2PredictTpu(PredictionServiceClient predictionServiceClient) {
    this.predictionServiceClient = predictionServiceClient;
  }

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String projectId = "YOUR_PROJECT_ID";
    String endpointRegion = "us-west1";
    String endpointId = "YOUR_ENDPOINT_ID";

    PredictionServiceSettings predictionServiceSettings =
        PredictionServiceSettings.newBuilder()
            .setEndpoint(String.format("%s-aiplatform.googleapis.com:443", endpointRegion))
            .build();
    PredictionServiceClient predictionServiceClient =
        PredictionServiceClient.create(predictionServiceSettings);
    Gemma2PredictTpu creator = new Gemma2PredictTpu(predictionServiceClient);

    creator.gemma2PredictTpu(projectId, endpointRegion, endpointId);
  }

  // Demonstrates how to run inference on a Gemma2 model
  // deployed to a Vertex AI endpoint with TPU accelerators.
  public String gemma2PredictTpu(String projectId, String region,
           String endpointId) throws IOException {
    Map<String, Object> paramsMap = new HashMap<>();
    paramsMap.put("temperature", 0.9);
    paramsMap.put("maxOutputTokens", 1024);
    paramsMap.put("topP", 1.0);
    paramsMap.put("topK", 1);
    Value parameters = mapToValue(paramsMap);
    // Prompt used in the prediction
    String instance = "{ \"prompt\": \"Why is the sky blue?\"}";
    Value.Builder instanceValue = Value.newBuilder();
    JsonFormat.parser().merge(instance, instanceValue);
    // Encapsulate the prompt in a correct format for TPUs
    // Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
    List<Value> instances = new ArrayList<>();
    instances.add(instanceValue.build());

    EndpointName endpointName = EndpointName.of(projectId, region, endpointId);

    PredictResponse predictResponse = this.predictionServiceClient
        .predict(endpointName, instances, parameters);
    String textResponse = predictResponse.getPredictions(0).getStringValue();
    System.out.println(textResponse);
    return textResponse;
  }

  private static Value mapToValue(Map<String, Object> map) throws InvalidProtocolBufferException {
    Gson gson = new Gson();
    String json = gson.toJson(map);
    Value.Builder builder = Value.newBuilder();
    JsonFormat.parser().merge(json, builder);
    return builder.build();
  }
}

Go

Avant d'essayer cet exemple, suivez les instructions de configuration pour Go décrites dans le guide de démarrage rapide de Vertex AI sur l'utilisation des bibliothèques clientes. Pour en savoir plus, consultez la documentation de référence de l'API Vertex AI pour Go.

Pour vous authentifier auprès de Vertex AI, configurez le service Identifiants par défaut de l'application. Pour en savoir plus, consultez Configurer l'authentification pour un environnement de développement local.

import (
	"context"
	"fmt"
	"io"

	"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"

	"google.golang.org/protobuf/types/known/structpb"
)

// predictTPU demonstrates how to run interference on a Gemma2 model deployed to a Vertex AI endpoint with TPU accelerators.
func predictTPU(w io.Writer, client PredictionsClient, projectID, location, endpointID string) error {
	ctx := context.Background()

	// Note: client can be initialized in the following way:
	// apiEndpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", location)
	// client, err := aiplatform.NewPredictionClient(ctx, option.WithEndpoint(apiEndpoint))
	// if err != nil {
	// 	return fmt.Errorf("unable to create prediction client: %v", err)
	// }
	// defer client.Close()

	gemma2Endpoint := fmt.Sprintf("projects/%s/locations/%s/endpoints/%s", projectID, location, endpointID)
	prompt := "Why is the sky blue?"
	parameters := map[string]interface{}{
		"temperature":     0.9,
		"maxOutputTokens": 1024,
		"topP":            1.0,
		"topK":            1,
	}

	// Encapsulate the prompt in a correct format for TPUs.
	// Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
	promptValue, err := structpb.NewValue(map[string]interface{}{
		"prompt":     prompt,
		"parameters": parameters,
	})
	if err != nil {
		fmt.Fprintf(w, "unable to convert prompt to Value: %v", err)
		return err
	}

	req := &aiplatformpb.PredictRequest{
		Endpoint:  gemma2Endpoint,
		Instances: []*structpb.Value{promptValue},
	}

	resp, err := client.Predict(ctx, req)
	if err != nil {
		return err
	}

	prediction := resp.GetPredictions()
	value := prediction[0].GetStringValue()
	fmt.Fprintf(w, "%v", value)

	return nil
}