Model Garden と Vertex AI TPU 対応エンドポイントを使用して Gemma をデプロイして推論する

このチュートリアルでは、Model Garden を使用して、TPU をベースとする Vertex AI エンドポイントに Gemma 2B オープンモデルをデプロイします。オンライン予測用にモデルを配信する前に、モデルをエンドポイントにデプロイする必要があります。モデルのデプロイでは、少ないレイテンシでオンライン予測を提供できるように、モデルに物理リソースを関連付けます。

Gemma 2B モデルをデプロイしたら、PredictionServiceClient を使用してトレーニング済みモデルを推論し、オンライン予測を取得します。オンライン予測は、エンドポイントにデプロイされたモデルに対して行われる同期リクエストです。

Model Garden を使用して Gemma をデプロイする

Gemma 2B モデルを、小規模から中規模のトレーニング用に最適化された ct5lp-hightpu-1t Compute Engine マシンタイプにデプロイします。このマシンには 1 つの TPU v5e アクセラレータが備わっています。TPU を使用してモデルをトレーニングする方法の詳細については、Cloud TPU v5e トレーニングをご覧ください。

このチュートリアルでは、Model Garden のモデルカードを使用して、命令でチューニングされた Gemma 2B オープンモデルをデプロイします。具体的なモデル バージョンは gemma2-2b-it です。-it指示のチューニング済みを表します。

Gemma 2B モデルのパラメータ サイズは小さいため、リソース要件が少なく、デプロイの柔軟性が高まります。

  1. Google Cloud コンソールで、[Model Garden] ページに移動します。

    Model Garden に移動

  2. [Gemma 2] モデルカードをクリックします。

    Gemma 2 に移動

  3. [デプロイ] をクリックして、[モデルのデプロイ] ペインを開きます。

  4. [モデルのデプロイ] ペインで、次の詳細を指定します。

    1. [デプロイ環境] で [Vertex AI] をクリックします。

    2. [モデルをデプロイする] セクションで、次の操作を行います。

      1. [リソース ID] に gemma-2b-it を選択します。

      2. [モデル名] と [エンドポイント名] のデフォルト値を使用します。例:

        • モデル名: gemma2-2b-it-1234567891234
        • エンドポイント名: gemma2-2b-it-mg-one-click-deploy

        エンドポイント名をメモします。コードサンプルで使用されているエンドポイント ID を確認するために必要になります。

    3. [デプロイの設定] セクションで、次の操作を行います。

      1. [基本] 設定はデフォルト オプションのままにします。

      2. [リージョン] では、デフォルト値のままにするか、リストからリージョンを選択します。リージョンをメモします。コードサンプルに必要になります。

      3. [マシン仕様] で、TPU を使用するインスタンス ct5lp-hightpu-1t (1 TPU_V5_LITEPOD; ct5lp-hightpu-1t) を選択します。

  5. [デプロイ] をクリックします。デプロイが完了すると、新しいエンドポイントの詳細が記載されたメールが届きます。[オンライン予測] > [エンドポイント] をクリックしてリージョンを選択して、エンドポイントの詳細を確認することもできます。

    エンドポイントに移動

PredictionServiceClient で Gemma 2B を推論する

Gemma 2B をデプロイしたら、PredictionServiceClient を使用して「空はなぜ青いの?」というプロンプトのオンライン予測を取得します。

コード パラメータ

PredictionServiceClient コードサンプルでは、以下を更新する必要があります。

  • PROJECT_ID: プロジェクト ID を確認する手順は次のとおりです。

    1. Google Cloud コンソールの [ようこそ] ページに移動します。

      [ようこそ] に移動

    2. ページ上部のプロジェクト選択ツールで、自分のプロジェクトを選択します。

      プロジェクト名、プロジェクト番号、プロジェクト ID は [ようこそ] の見出しの後に表示されます。

  • ENDPOINT_REGION: エンドポイントをデプロイしたリージョンです。

  • ENDPOINT_ID: エンドポイント ID を確認するには、コンソールで確認するか、gcloud ai endpoints list コマンドを実行します。[モデルをデプロイ] ペインからエンドポイント名とリージョンを取得します。

    コンソール

    [オンライン予測] > [エンドポイント] をクリックしてリージョンを選択すると、エンドポイントの詳細を確認できます。ID 列に表示される番号をメモします。

    エンドポイントに移動

    gcloud

    エンドポイントの詳細を表示するには、gcloud ai endpoints list コマンドを実行します。

    gcloud ai endpoints list \
      --region=ENDPOINT_REGION \
      --filter=display_name=ENDPOINT_NAME
    

    出力は次のようになります。

    Using endpoint [https://us-central1-aiplatform.googleapis.com/]
    ENDPOINT_ID: 1234567891234567891
    DISPLAY_NAME: gemma2-2b-it-mg-one-click-deploy
    

サンプルコード

言語のサンプルコードで、PROJECT_IDENDPOINT_REGIONENDPOINT_ID を更新します。次に、コードを実行します。

Python

Vertex AI SDK for Python のインストールまたは更新の方法については、Vertex AI SDK for Python をインストールするをご覧ください。 詳細については、Python API リファレンス ドキュメントをご覧ください。

"""
Sample to run inference on a Gemma2 model deployed to a Vertex AI endpoint with TPU accellerators.
"""

from google.cloud import aiplatform
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value

# TODO(developer): Update & uncomment lines below
# PROJECT_ID = "your-project-id"
# ENDPOINT_REGION = "your-vertex-endpoint-region"
# ENDPOINT_ID = "your-vertex-endpoint-id"

# Default configuration
config = {"max_tokens": 1024, "temperature": 0.9, "top_p": 1.0, "top_k": 1}

# Prompt used in the prediction
prompt = "Why is the sky blue?"

# Encapsulate the prompt in a correct format for TPUs
# Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
input = {"prompt": prompt}
input.update(config)

# Convert input message to a list of GAPIC instances for model input
instances = [json_format.ParseDict(input, Value())]

# Create a client
api_endpoint = f"{ENDPOINT_REGION}-aiplatform.googleapis.com"
client = aiplatform.gapic.PredictionServiceClient(
    client_options={"api_endpoint": api_endpoint}
)

# Call the Gemma2 endpoint
gemma2_end_point = (
    f"projects/{PROJECT_ID}/locations/{ENDPOINT_REGION}/endpoints/{ENDPOINT_ID}"
)
response = client.predict(
    endpoint=gemma2_end_point,
    instances=instances,
)
text_responses = response.predictions
print(text_responses[0])

Node.js

このサンプルを試す前に、Vertex AI クイックスタート: クライアント ライブラリの使用にある Node.js の設定手順を完了してください。詳細については、Vertex AI Node.js API のリファレンス ドキュメントをご覧ください。

Vertex AI に対する認証を行うには、アプリケーションのデフォルト認証情報を設定します。詳細については、ローカル開発環境の認証を設定するをご覧ください。

// Imports the Google Cloud Prediction Service Client library
const {
  // TODO(developer): Uncomment PredictionServiceClient before running the sample.
  // PredictionServiceClient,
  helpers,
} = require('@google-cloud/aiplatform');
/**
 * TODO(developer): Update these variables before running the sample.
 */
const projectId = 'your-project-id';
const endpointRegion = 'your-vertex-endpoint-region';
const endpointId = 'your-vertex-endpoint-id';

// Prompt used in the prediction
const prompt = 'Why is the sky blue?';

// Encapsulate the prompt in a correct format for TPUs
// Example format: [{prompt: 'Why is the sky blue?', temperature: 0.9}]
const input = {
  prompt,
  // Parameters for default configuration
  maxOutputTokens: 1024,
  temperature: 0.9,
  topP: 1.0,
  topK: 1,
};

// Convert input message to a list of GAPIC instances for model input
const instances = [helpers.toValue(input)];

// TODO(developer): Uncomment apiEndpoint and predictionServiceClient before running the sample.
// const apiEndpoint = `${endpointRegion}-aiplatform.googleapis.com`;

// Create a client
// predictionServiceClient = new PredictionServiceClient({apiEndpoint});

// Call the Gemma2 endpoint
const gemma2Endpoint = `projects/${projectId}/locations/${endpointRegion}/endpoints/${endpointId}`;

const [response] = await predictionServiceClient.predict({
  endpoint: gemma2Endpoint,
  instances,
});

const predictions = response.predictions;
const text = predictions[0].stringValue;

console.log('Predictions:', text);

Java

このサンプルを試す前に、Vertex AI クイックスタート: クライアント ライブラリの使用にある Java の設定手順を完了してください。詳細については、Vertex AI Java API のリファレンス ドキュメントをご覧ください。

Vertex AI に対する認証を行うには、アプリケーションのデフォルト認証情報を設定します。詳細については、ローカル開発環境の認証を設定するをご覧ください。


import com.google.cloud.aiplatform.v1.EndpointName;
import com.google.cloud.aiplatform.v1.PredictResponse;
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
import com.google.gson.Gson;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class Gemma2PredictTpu {
  private final PredictionServiceClient predictionServiceClient;

  // Constructor to inject the PredictionServiceClient
  public Gemma2PredictTpu(PredictionServiceClient predictionServiceClient) {
    this.predictionServiceClient = predictionServiceClient;
  }

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String projectId = "YOUR_PROJECT_ID";
    String endpointRegion = "us-west1";
    String endpointId = "YOUR_ENDPOINT_ID";

    PredictionServiceSettings predictionServiceSettings =
        PredictionServiceSettings.newBuilder()
            .setEndpoint(String.format("%s-aiplatform.googleapis.com:443", endpointRegion))
            .build();
    PredictionServiceClient predictionServiceClient =
        PredictionServiceClient.create(predictionServiceSettings);
    Gemma2PredictTpu creator = new Gemma2PredictTpu(predictionServiceClient);

    creator.gemma2PredictTpu(projectId, endpointRegion, endpointId);
  }

  // Demonstrates how to run inference on a Gemma2 model
  // deployed to a Vertex AI endpoint with TPU accelerators.
  public String gemma2PredictTpu(String projectId, String region,
           String endpointId) throws IOException {
    Map<String, Object> paramsMap = new HashMap<>();
    paramsMap.put("temperature", 0.9);
    paramsMap.put("maxOutputTokens", 1024);
    paramsMap.put("topP", 1.0);
    paramsMap.put("topK", 1);
    Value parameters = mapToValue(paramsMap);
    // Prompt used in the prediction
    String instance = "{ \"prompt\": \"Why is the sky blue?\"}";
    Value.Builder instanceValue = Value.newBuilder();
    JsonFormat.parser().merge(instance, instanceValue);
    // Encapsulate the prompt in a correct format for TPUs
    // Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
    List<Value> instances = new ArrayList<>();
    instances.add(instanceValue.build());

    EndpointName endpointName = EndpointName.of(projectId, region, endpointId);

    PredictResponse predictResponse = this.predictionServiceClient
        .predict(endpointName, instances, parameters);
    String textResponse = predictResponse.getPredictions(0).getStringValue();
    System.out.println(textResponse);
    return textResponse;
  }

  private static Value mapToValue(Map<String, Object> map) throws InvalidProtocolBufferException {
    Gson gson = new Gson();
    String json = gson.toJson(map);
    Value.Builder builder = Value.newBuilder();
    JsonFormat.parser().merge(json, builder);
    return builder.build();
  }
}

Go

このサンプルを試す前に、Vertex AI クイックスタート: クライアント ライブラリの使用にある Go の設定手順を完了してください。詳細については、Vertex AI Go API のリファレンス ドキュメントをご覧ください。

Vertex AI に対する認証を行うには、アプリケーションのデフォルト認証情報を設定します。詳細については、ローカル開発環境の認証を設定するをご覧ください。

import (
	"context"
	"fmt"
	"io"

	"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"

	"google.golang.org/protobuf/types/known/structpb"
)

// predictTPU demonstrates how to run interference on a Gemma2 model deployed to a Vertex AI endpoint with TPU accelerators.
func predictTPU(w io.Writer, client PredictionsClient, projectID, location, endpointID string) error {
	ctx := context.Background()

	// Note: client can be initialized in the following way:
	// apiEndpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", location)
	// client, err := aiplatform.NewPredictionClient(ctx, option.WithEndpoint(apiEndpoint))
	// if err != nil {
	// 	return fmt.Errorf("unable to create prediction client: %v", err)
	// }
	// defer client.Close()

	gemma2Endpoint := fmt.Sprintf("projects/%s/locations/%s/endpoints/%s", projectID, location, endpointID)
	prompt := "Why is the sky blue?"
	parameters := map[string]interface{}{
		"temperature":     0.9,
		"maxOutputTokens": 1024,
		"topP":            1.0,
		"topK":            1,
	}

	// Encapsulate the prompt in a correct format for TPUs.
	// Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
	promptValue, err := structpb.NewValue(map[string]interface{}{
		"prompt":     prompt,
		"parameters": parameters,
	})
	if err != nil {
		fmt.Fprintf(w, "unable to convert prompt to Value: %v", err)
		return err
	}

	req := &aiplatformpb.PredictRequest{
		Endpoint:  gemma2Endpoint,
		Instances: []*structpb.Value{promptValue},
	}

	resp, err := client.Predict(ctx, req)
	if err != nil {
		return err
	}

	prediction := resp.GetPredictions()
	value := prediction[0].GetStringValue()
	fmt.Fprintf(w, "%v", value)

	return nil
}