使用 Model Garden 和 Vertex AI TPU 支援的端點,部署並推論 Gemma

在本教學課程中,您會使用 Model Garden 將開放式模型 Gemma 2B 部署至採用 TPU 的 Vertex AI 端點。您必須先將模型部署至端點,才能使用模型進行線上預測。部署過程中,系統會將實體資源與模型建立關聯,讓模型以低延遲的方式提供線上預測結果。

部署 Gemma 2B 模型後,您可以使用 PredictionServiceClient 推論訓練好的模型,以取得線上預測。線上預測是對部署至端點的模型發出的同步要求。

使用 Model Garden 部署 Gemma

您將 Gemma 2B 模型部署至專為中小規模訓練最佳化的 ct5lp-hightpu-1t Compute Engine 機器類型。這部機器有一個 TPU v5e 加速器。如要進一步瞭解如何使用 TPU 訓練模型,請參閱 Cloud TPU v5e 訓練

在本教學課程中,您會使用 Model Garden 中的模型資訊卡,部署經過指令微調的 Gemma 2B 開放式模型。具體模型版本為 gemma2-2b-it,其中 -it 代表指令調整

Gemma 2B 模型參數較少,因此資源需求較低,部署彈性也較高。

  1. 前往 Google Cloud 控制台的「Model Garden」頁面。

    前往 Model Garden

  2. 點按「Gemma 2」Gemma 2模型資訊卡。

    前往 Gemma 2

  3. 按一下「Deploy」(部署),開啟「Deploy model」(部署模型)窗格。

  4. 在「部署模型」窗格中,指定下列詳細資料。

    1. 在「部署環境」中,按一下「Vertex AI」

    2. 在「Deploy model」(部署模型) 部分:

      1. 在「Resource ID」(資源 ID) 中選擇 gemma-2b-it

      2. 接受「模型名稱」和「端點名稱」的預設值。例如:

        • 模型名稱:gemma2-2b-it-1234567891234
        • 端點名稱:gemma2-2b-it-mg-one-click-deploy

        請記下端點名稱,您需要這個 ID,才能找出程式碼範例中使用的端點 ID。

    3. 在「部署作業設定」部分:

      1. 接受「基本」設定的預設選項。

      2. 針對「Region」(區域),請接受預設值,或是從清單中選擇區域。請記下這個區域。程式碼範例會用到這項資訊。

      3. 在「Machine spec」(機器規格) 中,選擇以 TPU 為後端的執行個體: ct5lp-hightpu-1t (1 TPU_V5_LITEPOD; ct5lp-hightpu-1t)

  5. 按一下「Deploy」(部署)。部署完成後,您會收到一封電子郵件,內含新端點的詳細資料。您也可以依序點選「線上預測」>「端點」,然後選取所在區域,查看端點詳細資料。

    前往 Endpoints

使用 PredictionServiceClient 推論 Gemma 2B

部署 Gemma 2B 後,您可以使用 PredictionServiceClient,針對「為什麼天空是藍的?」提示取得線上預測

程式碼參數

您必須更新 PredictionServiceClient 程式碼範例中的下列項目。

  • PROJECT_ID:如要找出專案 ID,請按照下列步驟操作。

    1. 前往 Google Cloud 控制台的「Welcome」(歡迎使用) 頁面。

      前往「歡迎」

    2. 在頁面頂端的專案挑選器中選取專案。

      專案名稱、專案編號和專案 ID 會顯示在「歡迎」標題下方。

  • ENDPOINT_REGION:這是部署端點的區域。

  • ENDPOINT_ID:如要找出端點 ID,請在控制台中查看,或執行 gcloud ai endpoints list 指令。您需要「Deploy model」(部署模型) 窗格中的端點名稱和區域。

    控制台

    如要查看端點詳細資料,請依序點選「線上預測」>「端點」,然後選取您的區域。請記下「ID」欄中顯示的號碼。

    前往 Endpoints

    gcloud

    您可以執行 gcloud ai endpoints list 指令,查看端點詳細資料。

    gcloud ai endpoints list \
      --region=ENDPOINT_REGION \
      --filter=display_name=ENDPOINT_NAME
    

    輸出內容如下所示。

    Using endpoint [https://us-central1-aiplatform.googleapis.com/]
    ENDPOINT_ID: 1234567891234567891
    DISPLAY_NAME: gemma2-2b-it-mg-one-click-deploy
    

程式碼範例

在您所用語言的範例程式碼中,更新 PROJECT_IDENDPOINT_REGIONENDPOINT_ID。然後執行程式碼。

Python

如要瞭解如何安裝或更新 Python 適用的 Vertex AI SDK,請參閱「安裝 Python 適用的 Vertex AI SDK」。 詳情請參閱 Python API 參考說明文件

"""
Sample to run inference on a Gemma2 model deployed to a Vertex AI endpoint with TPU accellerators.
"""

from google.cloud import aiplatform
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value

# TODO(developer): Update & uncomment lines below
# PROJECT_ID = "your-project-id"
# ENDPOINT_REGION = "your-vertex-endpoint-region"
# ENDPOINT_ID = "your-vertex-endpoint-id"

# Default configuration
config = {"max_tokens": 1024, "temperature": 0.9, "top_p": 1.0, "top_k": 1}

# Prompt used in the prediction
prompt = "Why is the sky blue?"

# Encapsulate the prompt in a correct format for TPUs
# Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
input = {"prompt": prompt}
input.update(config)

# Convert input message to a list of GAPIC instances for model input
instances = [json_format.ParseDict(input, Value())]

# Create a client
api_endpoint = f"{ENDPOINT_REGION}-aiplatform.googleapis.com"
client = aiplatform.gapic.PredictionServiceClient(
    client_options={"api_endpoint": api_endpoint}
)

# Call the Gemma2 endpoint
gemma2_end_point = (
    f"projects/{PROJECT_ID}/locations/{ENDPOINT_REGION}/endpoints/{ENDPOINT_ID}"
)
response = client.predict(
    endpoint=gemma2_end_point,
    instances=instances,
)
text_responses = response.predictions
print(text_responses[0])

Node.js

在試用這個範例之前,請先按照Node.js使用用戶端程式庫的 Vertex AI 快速入門中的操作說明進行設定。 詳情請參閱 Vertex AI Node.js API 參考說明文件

如要向 Vertex AI 進行驗證,請設定應用程式預設憑證。 詳情請參閱「為本機開發環境設定驗證」。

// Imports the Google Cloud Prediction Service Client library
const {
  // TODO(developer): Uncomment PredictionServiceClient before running the sample.
  // PredictionServiceClient,
  helpers,
} = require('@google-cloud/aiplatform');
/**
 * TODO(developer): Update these variables before running the sample.
 */
const projectId = 'your-project-id';
const endpointRegion = 'your-vertex-endpoint-region';
const endpointId = 'your-vertex-endpoint-id';

// Prompt used in the prediction
const prompt = 'Why is the sky blue?';

// Encapsulate the prompt in a correct format for TPUs
// Example format: [{prompt: 'Why is the sky blue?', temperature: 0.9}]
const input = {
  prompt,
  // Parameters for default configuration
  maxOutputTokens: 1024,
  temperature: 0.9,
  topP: 1.0,
  topK: 1,
};

// Convert input message to a list of GAPIC instances for model input
const instances = [helpers.toValue(input)];

// TODO(developer): Uncomment apiEndpoint and predictionServiceClient before running the sample.
// const apiEndpoint = `${endpointRegion}-aiplatform.googleapis.com`;

// Create a client
// predictionServiceClient = new PredictionServiceClient({apiEndpoint});

// Call the Gemma2 endpoint
const gemma2Endpoint = `projects/${projectId}/locations/${endpointRegion}/endpoints/${endpointId}`;

const [response] = await predictionServiceClient.predict({
  endpoint: gemma2Endpoint,
  instances,
});

const predictions = response.predictions;
const text = predictions[0].stringValue;

console.log('Predictions:', text);

Java

在試用這個範例之前,請先按照Java使用用戶端程式庫的 Vertex AI 快速入門中的操作說明進行設定。 詳情請參閱 Vertex AI Java API 參考說明文件

如要向 Vertex AI 進行驗證,請設定應用程式預設憑證。 詳情請參閱「為本機開發環境設定驗證」。


import com.google.cloud.aiplatform.v1.EndpointName;
import com.google.cloud.aiplatform.v1.PredictResponse;
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
import com.google.gson.Gson;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class Gemma2PredictTpu {
  private final PredictionServiceClient predictionServiceClient;

  // Constructor to inject the PredictionServiceClient
  public Gemma2PredictTpu(PredictionServiceClient predictionServiceClient) {
    this.predictionServiceClient = predictionServiceClient;
  }

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String projectId = "YOUR_PROJECT_ID";
    String endpointRegion = "us-west1";
    String endpointId = "YOUR_ENDPOINT_ID";

    PredictionServiceSettings predictionServiceSettings =
        PredictionServiceSettings.newBuilder()
            .setEndpoint(String.format("%s-aiplatform.googleapis.com:443", endpointRegion))
            .build();
    PredictionServiceClient predictionServiceClient =
        PredictionServiceClient.create(predictionServiceSettings);
    Gemma2PredictTpu creator = new Gemma2PredictTpu(predictionServiceClient);

    creator.gemma2PredictTpu(projectId, endpointRegion, endpointId);
  }

  // Demonstrates how to run inference on a Gemma2 model
  // deployed to a Vertex AI endpoint with TPU accelerators.
  public String gemma2PredictTpu(String projectId, String region,
           String endpointId) throws IOException {
    Map<String, Object> paramsMap = new HashMap<>();
    paramsMap.put("temperature", 0.9);
    paramsMap.put("maxOutputTokens", 1024);
    paramsMap.put("topP", 1.0);
    paramsMap.put("topK", 1);
    Value parameters = mapToValue(paramsMap);
    // Prompt used in the prediction
    String instance = "{ \"prompt\": \"Why is the sky blue?\"}";
    Value.Builder instanceValue = Value.newBuilder();
    JsonFormat.parser().merge(instance, instanceValue);
    // Encapsulate the prompt in a correct format for TPUs
    // Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
    List<Value> instances = new ArrayList<>();
    instances.add(instanceValue.build());

    EndpointName endpointName = EndpointName.of(projectId, region, endpointId);

    PredictResponse predictResponse = this.predictionServiceClient
        .predict(endpointName, instances, parameters);
    String textResponse = predictResponse.getPredictions(0).getStringValue();
    System.out.println(textResponse);
    return textResponse;
  }

  private static Value mapToValue(Map<String, Object> map) throws InvalidProtocolBufferException {
    Gson gson = new Gson();
    String json = gson.toJson(map);
    Value.Builder builder = Value.newBuilder();
    JsonFormat.parser().merge(json, builder);
    return builder.build();
  }
}

Go

在試用這個範例之前,請先按照Go使用用戶端程式庫的 Vertex AI 快速入門中的操作說明進行設定。 詳情請參閱 Vertex AI Go API 參考說明文件

如要向 Vertex AI 進行驗證,請設定應用程式預設憑證。 詳情請參閱「為本機開發環境設定驗證」。

import (
	"context"
	"fmt"
	"io"

	"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"

	"google.golang.org/protobuf/types/known/structpb"
)

// predictTPU demonstrates how to run interference on a Gemma2 model deployed to a Vertex AI endpoint with TPU accelerators.
func predictTPU(w io.Writer, client PredictionsClient, projectID, location, endpointID string) error {
	ctx := context.Background()

	// Note: client can be initialized in the following way:
	// apiEndpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", location)
	// client, err := aiplatform.NewPredictionClient(ctx, option.WithEndpoint(apiEndpoint))
	// if err != nil {
	// 	return fmt.Errorf("unable to create prediction client: %v", err)
	// }
	// defer client.Close()

	gemma2Endpoint := fmt.Sprintf("projects/%s/locations/%s/endpoints/%s", projectID, location, endpointID)
	prompt := "Why is the sky blue?"
	parameters := map[string]interface{}{
		"temperature":     0.9,
		"maxOutputTokens": 1024,
		"topP":            1.0,
		"topK":            1,
	}

	// Encapsulate the prompt in a correct format for TPUs.
	// Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
	promptValue, err := structpb.NewValue(map[string]interface{}{
		"prompt":     prompt,
		"parameters": parameters,
	})
	if err != nil {
		fmt.Fprintf(w, "unable to convert prompt to Value: %v", err)
		return err
	}

	req := &aiplatformpb.PredictRequest{
		Endpoint:  gemma2Endpoint,
		Instances: []*structpb.Value{promptValue},
	}

	resp, err := client.Predict(ctx, req)
	if err != nil {
		return err
	}

	prediction := resp.GetPredictions()
	value := prediction[0].GetStringValue()
	fmt.Fprintf(w, "%v", value)

	return nil
}