Vertex AI のオンライン予測サービスを使用すると、独自の予測モデル エンドポイントに同期リクエストを行うことができます。
このページでは、低レイテンシでオンライン予測を提供できるように、モデルにリクエストを送信する方法について説明します。
始める前に
オンライン予測 API の使用を開始するには、プロジェクトと適切な認証情報が必要です。
オンライン予測を取得する前に、次の手順を行います。
- Vertex AI 用にプロジェクトを設定します。
オンライン予測にアクセスするために必要な権限を取得するには、プロジェクト IAM 管理者に Vertex AI 予測ユーザー(
vertex-ai-prediction-user)ロールの付与を依頼してください。このロールの詳細については、IAM 権限を準備するをご覧ください。
サポートされているコンテナのいずれかをターゲットとする予測モデルを作成してトレーニングします。
予測クラスタを作成し、プロジェクトで外部からのトラフィックが許可されていることを確認します。
予測モデルの
Endpointカスタム リソースの詳細を表示します。kubectl --kubeconfig PREDICTION_CLUSTER_KUBECONFIG get endpoint PREDICTION_ENDPOINT -n PROJECT_NAMESPACE -o jsonpath='{.status.endpointFQDN}'次のように置き換えます。
PREDICTION_CLUSTER_KUBECONFIG: 予測クラスタの kubeconfig ファイルへのパス。PREDICTION_ENDPOINT: エンドポイントの名前。PROJECT_NAMESPACE: 予測プロジェクトの Namespace の名前。
出力には、
statusフィールドが表示され、endpointFQDNフィールドにエンドポイントの完全修飾ドメイン名が表示されます。このエンドポイント URL パスを登録して、リクエストに使用します。
環境変数を設定する
Python スクリプトを使用してモデル エンドポイントにリクエストを送信し、プロジェクトでサービス アカウントを設定してプログラムで承認済み API 呼び出しを行う場合は、スクリプトで環境変数を定義して、実行時にサービス アカウント キーなどの値にアクセスできます。
Python スクリプトで必要な環境変数を設定する手順は次のとおりです。
JupyterLab ノートブックを作成して、オンライン予測 API を操作します。
JupyterLab ノートブックで Python スクリプトを作成します。
Python スクリプトに次のコードを追加します。
import os os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "APPLICATION_DEFAULT_CREDENTIALS_FILENAME"APPLICATION_DEFAULT_CREDENTIALS_FILENAMEは、プロジェクトで作成したサービス アカウント キーを含む JSON ファイルの名前に置き換えます(例:my-service-key.json)。Python スクリプトを
prediction.pyなどの名前で保存します。Python スクリプトを実行して環境変数を設定します。
python SCRIPT_NAMESCRIPT_NAMEは、Python スクリプトに付けた名前(prediction.pyなど)に置き換えます。
エンドポイントにリクエストを送信する
モデルのエンドポイントにリクエストを送信して、オンライン予測を取得します。
curl
curl リクエストを行う手順は次のとおりです。
リクエスト本文用の
request.jsonという名前の JSON ファイルを作成します。ターゲット コンテナに必要なリクエスト本文の詳細を使用して、オンライン予測の入力を追加してフォーマットする必要があります。
次のリクエストを行います。
curl -X POST -H "Content-Type: application/json; charset=utf-8" -H "Authorization: Bearer TOKEN" https://ENDPOINT_HOSTNAME:443/v1/model:predict -d @request.json次のように置き換えます。
TOKEN: 取得した認証トークン。ENDPOINT_HOSTNAME: オンライン予測リクエストのモデル エンドポイントの FQDN。
成功すると、オンライン予測リクエストに対する JSON レスポンスが返されます。
次の出力は、その例を示しています。
{
"predictions": [[-357.10849], [-171.621658]
]
}
レスポンスの詳細については、レスポンス本文の詳細をご覧ください。
Python
Python スクリプトからオンライン予測サービスを使用する手順は次のとおりです。
リクエスト本文用の
request.jsonという名前の JSON ファイルを作成します。ターゲット コンテナに必要なリクエスト本文の詳細を使用して、オンライン予測の入力を追加してフォーマットする必要があります。
作成した Python スクリプトに次のコードを追加します。
import json import os from typing import Sequence import grpc from absl import app from absl import flags import google from google.auth.transport import requests from google.protobuf import json_format from google.protobuf.struct_pb2 import Value from google.cloud.aiplatform_v1.services import prediction_service _INPUT = flags.DEFINE_string("input", None, "input", required=True) _ENDPOINT_HOSTNAME = flags.DEFINE_string("endpoint_hostname", None, "Prediction endpoint FQDN", required=True) _PROJECT_NAME = flags.DEFINE_string("project_name", None, "project name", required=True) _ENDPOINT_NAME = flags.DEFINE_string("endpoint_name", None, "endpoint name", required=True) os.environ["GRPC_DEFAULT_SSL_ROOTS_FILE_PATH"] = "path-to-ca-cert-file.cert" def get_sts_token(endpoint_hostname): creds = None try: creds, _ = google.auth.default() creds = creds.with_gdch_audience("https://"+endpoint_hostname+":443") req = requests.Request() creds.refresh(req) print("Got token: ") print(creds.token) except Exception as e: print("Caught exception" + str(e)) raise e return creds.token # predict_client_secure builds a client that requires TLS def predict_client_secure(endpoint_hostname, token): with open(os.environ["GRPC_DEFAULT_SSL_ROOTS_FILE_PATH"], 'rb') as f: channel_creds = grpc.ssl_channel_credentials(f.read()) call_creds = grpc.access_token_call_credentials(token) creds = grpc.composite_channel_credentials( channel_creds, call_creds, ) client = prediction_service.PredictionServiceClient( transport=prediction_service.transports.grpc.PredictionServiceGrpcTransport( channel=grpc.secure_channel(target=endpoint_hostname+":443", credentials=creds))) return client def predict_func(client, instances): # The endpoint resource name is required for authorization. # A wrong value might lead to an access denied error. endpoint_resource_name = f"projects/{_PROJECT_NAME.value}/locations/{_PROJECT_NAME.value}/endpoints/{_ENDPOINT_NAME.value}" resp = client.predict( endpoint=endpoint_resource_name, instances=instances, metadata=[("x-vertex-ai-endpoint-id", _ENDPOINT_NAME.value)] ) print(resp) def main(argv: Sequence[str]): del argv # Unused. with open(_INPUT.value) as json_file: data = json.load(json_file) instances = [json_format.ParseDict(s, Value()) for s in data["instances"]] token = get_sts_token(_ENDPOINT_HOSTNAME.value) client = predict_client_secure(_ENDPOINT_HOSTNAME.value, token) predict_func(client=client, instances=instances) if __name__=="__main__": app.run(main)Python スクリプトを
prediction.pyなどの名前で保存します。予測サーバーにリクエストを送信します。
python SCRIPT_NAME --input request.json \ --endpoint_hostname ENDPOINT_FQDN \ --project_name PROJECT_NAME \ --endpoint_name ENDPOINT_NAME \次のように置き換えます。
SCRIPT_NAME: Python スクリプトの名前(prediction.pyなど)。ENDPOINT_FQDN: オンライン予測リクエストのエンドポイントの完全修飾ドメイン名。PROJECT_NAME: エンドポイントのプロジェクト名。ENDPOINT_NAME: 呼び出すエンドポイントの名前。
成功すると、オンライン予測リクエストに対する JSON レスポンスが返されます。レスポンスの詳細については、レスポンス本文の詳細をご覧ください。