开发自定义代理

Vertex AI Agent Engine 中的代理模板定义为 Python 类。以下步骤展示了如何创建自定义模板,以用于实例化可在 Vertex AI 上部署的代理:

  1. 基本示例
  2. (可选)流式传输响应
  3. (可选)注册自定义方法
  4. (可选)提供类型注解
  5. (可选)将跟踪记录发送到 Cloud Trace
  6. (可选)使用环境变量
  7. (可选)与 Secret Manager 集成
  8. (可选)处理凭证
  9. (可选)处理错误

基本示例

举一个基本示例,以下 Python 类是一个模板,用于实例化可在 Vertex AI 上部署的代理(您可以为 CLASS_NAME 变量赋值,例如 MyAgent):

from typing import Callable, Sequence

class CLASS_NAME:
    def __init__(
        self,
        model: str,
        tools: Sequence[Callable],
        project: str,
        location: str,
    ):
        self.model_name = model
        self.tools = tools
        self.project = project
        self.location = location

    def set_up(self):
        import vertexai
        from langchain_google_vertexai import ChatVertexAI
        from langgraph.prebuilt import create_react_agent

        vertexai.init(project=self.project, location=self.location)

        model = ChatVertexAI(model_name=self.model_name)
        self.graph = create_react_agent(model, tools=self.tools)

    def query(self, **kwargs):
        return self.graph.invoke(**kwargs)

部署考虑事项

编写 Python 类时,以下三种方法非常重要:

  1. __init__()
    • 此方法仅用于代理配置参数。例如,您可以使用此方法从用户那里收集模型参数和安全属性作为输入参数。您还可以使用此方法收集项目 ID、区域、应用凭据和 API 密钥等参数。
    • 构造函数返回的对象必须是“可序列化”的,才能部署到 Vertex AI Agent Engine。因此,您应在 .set_up 方法中初始化服务客户端并建立与数据库的连接,而不是在 __init__ 方法中执行这些操作。
    • 此方法为可选方法。如果未指定,Vertex AI 会使用该类的默认 Python 构造函数。
  2. set_up()
    • 您必须使用此方法来定义代理初始化逻辑。例如,您可以使用此方法建立与数据库或依赖服务的连接、导入依赖软件包或预计算用于处理查询的数据。
    • 此方法为可选方法。如果未指定,Vertex AI 会假定代理无需在处理用户查询之前调用 .set_up 方法。
  3. query()/stream_query()
    • 使用 query() 可将完整响应作为单个结果返回。
    • 使用 stream_query() 可在响应可用时以分块的形式返回响应,从而实现流式传输体验。stream_query 方法必须返回可迭代对象(例如生成器)才能实现流式传输。
    • 如果您希望支持与代理之间的单次响应和流式传输交互,可以同时实现这两种方法。
    • 您应为此方法提供一个清晰的文档字符串,用于定义它的作用、记录其属性以及为其输入提供类型注解。请避免在 querystream_query 方法中使用可变参数。

在本地实例化代理

您可以使用以下代码创建代理的本地实例:

agent = CLASS_NAME(
    model=model,  # Required.
    tools=[get_exchange_rate],  # Optional.
    project="PROJECT_ID",
    location="LOCATION",
)
agent.set_up()

测试 query 方法

您可以通过向本地实例发送查询来测试代理:

response = agent.query(
    input="What is the exchange rate from US dollars to Swedish currency?"
)

print(response)

响应是一个类似于以下内容的字典:

{"input": "What is the exchange rate from US dollars to Swedish currency?",
 # ...
 "output": "For 1 US dollar you will get 10.7345 Swedish Krona."}

异步查询

如需异步响应查询,您可以定义一个返回 Python 协程的方法(例如 async_query)。例如,以下模板扩展了基本示例以进行异步响应,并且可在 Vertex AI 上部署:

class AsyncAgent(CLASS_NAME):

    async def async_query(self, **kwargs):
        from langchain.load.dump import dumpd

        for chunk in self.graph.ainvoke(**kwargs):
            yield dumpd(chunk)

agent = AsyncAgent(
    model=model,                # Required.
    tools=[get_exchange_rate],  # Optional.
    project="PROJECT_ID",
    location="LOCATION",
)
agent.set_up()

测试 async_query 方法

您可以通过调用 async_query 方法在本地测试代理。示例如下:

response = await agent.async_query(
    input="What is the exchange rate from US dollars to Swedish Krona today?"
)
print(response)

响应是一个类似于以下内容的字典:

{"input": "What is the exchange rate from US dollars to Swedish currency?",
 # ...
 "output": "For 1 US dollar you will get 10.7345 Swedish Krona."}

流式响应

如需流式传输对查询的响应,您可以定义一个会生成响应的名为 stream_query 的方法。例如,以下模板扩展了基本示例以流式传输响应,并且可在 Vertex AI 上部署:

from typing import Iterable

class StreamingAgent(CLASS_NAME):

    def stream_query(self, **kwargs) -> Iterable:
        from langchain.load.dump import dumpd

        for chunk in self.graph.stream(**kwargs):
            yield dumpd(chunk)

agent = StreamingAgent(
    model=model,                # Required.
    tools=[get_exchange_rate],  # Optional.
    project="PROJECT_ID",
    location="LOCATION",
)
agent.set_up()

使用流式传输 API 时,请注意以下一些关键事项:

  • 超时上限:流式传输响应的超时上限为 10 分钟。如果您的代理需要更长的处理时间,请考虑将任务分解为较小的分块。
  • 流式传输模型和链:LangChain 的 Runnable 接口支持流式传输,因此您不仅可以流式传输来自代理的响应,还可以流式传输来自模型和链的响应。
  • LangChain 兼容性:请注意,目前不支持异步方法,例如 LangChain 的 astream_event 方法。
  • 限制内容生成:如果您遇到背压问题(即提供方生成数据的速度快于使用方处理数据的速度),则应限制内容生成速率。这有助于防止缓冲区溢出,并确保流畅的流式传输体验。

测试 stream_query 方法

您可以通过调用 stream_query 方法并迭代结果,在本地测试流式传输查询。示例如下:

import pprint

for chunk in agent.stream_query(
    input="What is the exchange rate from US dollars to Swedish currency?"
):
    # Use pprint with depth=1 for a more concise, high-level view of the
    # streamed output.
    # To see the full content of the chunk, use:
    # print(chunk)
    pprint.pprint(chunk, depth=1)

此代码会在响应生成时输出响应的每个分块。输出可能如下所示:

{'actions': [...], 'messages': [...]}
{'messages': [...], 'steps': [...]}
{'messages': [...],
 'output': 'The exchange rate from US dollars to Swedish currency is 1 USD to '
           '10.5751 SEK. \n'}

在此示例中,每个分块都包含有关响应的不同信息,例如代理执行的操作、交换的消息和最终输出。

异步流式传输响应

如需异步流式传输响应,您可以定义一个返回异步生成器的方法(例如 async_stream_query)。例如,以下模板扩展了基本示例以异步流式传输响应,并且可在 Vertex AI 上部署:

class AsyncStreamingAgent(CLASS_NAME):

    async def async_stream_query(self, **kwargs):
        from langchain.load.dump import dumpd

        for chunk in self.graph.astream(**kwargs):
            yield dumpd(chunk)

agent = AsyncStreamingAgent(
    model=model,                # Required.
    tools=[get_exchange_rate],  # Optional.
    project="PROJECT_ID",
    location="LOCATION",
)
agent.set_up()

测试 async_stream_query 方法

与用于测试流式传输查询的代码类似,您可以通过调用 async_stream_query 方法并迭代结果,在本地测试代理。示例如下:

import pprint

async for chunk in agent.async_stream_query(
    input="What is the exchange rate from US dollars to Swedish currency?"
):
    # Use pprint with depth=1 for a more concise, high-level view of the
    # streamed output.
    # To see the full content of the chunk, use:
    # print(chunk)
    pprint.pprint(chunk, depth=1)

此代码会在响应生成时输出响应的每个分块。输出可能如下所示:

{'actions': [...], 'messages': [...]}
{'messages': [...], 'steps': [...]}
{'messages': [...],
 'output': 'The exchange rate from US dollars to Swedish currency is 1 USD to '
           '10.5751 SEK. \n'}

注册自定义方法

默认情况下,方法 querystream_query 在所部署的代理中注册为操作。您可以使用 register_operations 方法替换默认行为并定义要注册的操作集。操作可以注册为标准(以空字符串 "" 表示)或流式传输 ("stream") 执行模式。

如需注册多个操作,您可以定义一个名为 register_operations 的方法,其中列出在部署代理时可供用户使用的方法。在以下示例代码中,register_operations 方法会使所部署的代理将 queryget_state 注册为同步运行的操作,并将 stream_queryget_state_history 注册为流式传输响应的操作:

from typing import Iterable

class CustomAgent(StreamingAgent):

    def get_state(self) -> dict: # new synchronous method
        return self.graph.get_state(**kwargs)._asdict()

    def get_state_history(self) -> Iterable: # new streaming operation
        for state_snapshot in self.graph.get_state_history(**kwargs):
            yield state_snapshot._asdict()

    def register_operations(self):
        return {
            # The list of synchronous operations to be registered
            "": ["query", "get_state"],
            # The list of streaming operations to be registered
            "stream": ["stream_query", "get_state_history"],
        }

您可以直接在代理的本地实例上调用自定义方法来测试这些方法,这与测试 querystream_query 方法的方式类似。

提供类型注解

您可以使用类型注解来指定代理方法的预期输入和输出类型。部署代理时,代理支持的操作的输入和输出中仅支持可序列化为 JSON 的类型。可以使用 TypedDict 或 Pydantic 模型对输入和输出的架构进行注解。

在以下示例中,我们将输入注解为 TypedDict,并使用 ._asdict() 方法将 .get_state 的原始输出(即 NamedTuple)转换为可序列化的字典:

from typing import Any, Dict, TypedDict

# schemas.py
class RunnableConfig(TypedDict, total=False):
    metadata: Dict[str, Any]
    configurable: Dict[str, Any]

# agents.py
class AnnotatedAgent(CLASS_NAME):

    def get_state(self, config: RunnableConfig) -> dict:
        return self.graph.get_state(config=config)._asdict()

    def register_operations(self):
        return {"": ["query", "get_state"]}

将跟踪记录发送到 Cloud Trace

如需使用支持 OpenTelemetry 的插桩库将跟踪记录发送到 Cloud Trace,您可以在 .set_up 方法中导入并初始化这些库。对于常见的代理框架,您或许可以将 Open Telemetry Google Cloud 集成OpenInferenceOpenLLMetry 等插桩框架结合使用。

例如,以下模板是对基本示例的修改,以将跟踪记录导出到 Cloud Trace:

OpenInference

首先,通过运行以下命令,使用 pip 安装所需的软件包

pip install openinference-instrumentation-langchain==0.1.34

接下来,导入并初始化插桩器:

from typing import Callable, Sequence

class CLASS_NAME:
    def __init__(
        self,
        model: str,
        tools: Sequence[Callable],
        project: str,
        location: str,
    ):
        self.model_name = model
        self.tools = tools
        self.project = project
        self.location = location

    def set_up(self):
        # The additional code required for tracing instrumentation.
        from opentelemetry import trace
        from opentelemetry.exporter.cloud_trace import CloudTraceSpanExporter
        from opentelemetry.sdk.trace import TracerProvider
        from opentelemetry.sdk.trace.export import SimpleSpanProcessor
        from openinference.instrumentation.langchain import LangChainInstrumentor
        import google.cloud.trace_v2 as cloud_trace_v2
        import google.auth

        credentials, _ = google.auth.default()

        trace.set_tracer_provider(TracerProvider())
        cloud_trace_exporter = CloudTraceSpanExporter(
            project_id=self.project,
            client=cloud_trace_v2.TraceServiceClient(
                credentials=credentials.with_quota_project(self.project),
            ),
        )
        trace.get_tracer_provider().add_span_processor(
            SimpleSpanProcessor(cloud_trace_exporter)
        )
        LangChainInstrumentor().instrument()
        # end of additional code required

        import vertexai
        from langchain_google_vertexai import ChatVertexAI
        from langgraph.prebuilt import create_react_agent

        vertexai.init(project=self.project, location=self.location)

        model = ChatVertexAI(model_name=self.model_name)
        self.graph = create_react_agent(model, tools=self.tools)

    def query(self, **kwargs):
        return self.graph.invoke(**kwargs)

OpenLLMetry

首先,通过运行以下命令,使用 pip 安装所需的软件包

pip install opentelemetry-instrumentation-langchain==0.38.10

接下来,导入并初始化插桩器:

from typing import Callable, Sequence

class CLASS_NAME:
    def __init__(
        self,
        model: str,
        tools: Sequence[Callable],
        project: str,
        location: str,
    ):
        self.model_name = model
        self.tools = tools
        self.project = project
        self.location = location

    def set_up(self):
        # The additional code required for tracing instrumentation.
        from opentelemetry import trace
        from opentelemetry.exporter.cloud_trace import CloudTraceSpanExporter
        from opentelemetry.sdk.trace import TracerProvider
        from opentelemetry.sdk.trace.export import SimpleSpanProcessor
        from opentelemetry.instrumentation.langchain import LangchainInstrumentor
        import google.cloud.trace_v2 as cloud_trace_v2
        import google.auth

        credentials, _ = google.auth.default()

        trace.set_tracer_provider(TracerProvider())
        cloud_trace_exporter = CloudTraceSpanExporter(
            project_id=self.project,
            client=cloud_trace_v2.TraceServiceClient(
                credentials=credentials.with_quota_project(self.project),
            ),
        )
        trace.get_tracer_provider().add_span_processor(
            SimpleSpanProcessor(cloud_trace_exporter)
        )
        LangchainInstrumentor().instrument()
        # end of additional code required

        import vertexai
        from langchain_google_vertexai import ChatVertexAI
        from langgraph.prebuilt import create_react_agent

        vertexai.init(project=self.project, location=self.location)

        model = ChatVertexAI(model_name=self.model_name)
        self.graph = create_react_agent(model, tools=self.tools)

    def query(self, **kwargs):
        return self.graph.invoke(**kwargs)

使用环境变量

如需设置环境变量,请确保在开发期间可以通过 os.environ 使用这些变量,并在部署代理时按照定义环境变量中的说明操作。

与 Secret Manager 集成

如需与 Secret Manager 集成,请执行以下操作:

  1. 运行以下命令来安装客户端库

    pip install google-cloud-secret-manager
  2. 按照为所部署的代理授予角色中的说明,通过 Google Cloud 控制台向服务账号授予“Secret Manager Secret Accessor”角色 (roles/secretmanager.secretAccessor)。

  3. .set_up 方法中导入并初始化客户端,并在需要时获取相应的 Secret。例如,以下模板是对基本示例的修改,以使用存储在 Secret Manager 中ChatAnthropic API 密钥:

from typing import Callable, Sequence

class CLASS_NAME:
    def __init__(
        self,
        model: str,
        tools: Sequence[Callable],
        project: str,
    ):
        self.model_name = model
        self.tools = tools
        self.project = project
        self.secret_id = secret_id # <- new

    def set_up(self):
        from google.cloud import secretmanager
        from langchain_anthropic import ChatAnthropic
        from langgraph.prebuilt import create_react_agent

        # Get the API Key from Secret Manager here.
        self.secret_manager_client = secretmanager.SecretManagerServiceClient()
        secret_version = self.secret_manager_client.access_secret_version(request={
            "name": "projects/PROJECT_ID/secrets/SECRET_ID/versions/SECRET_VERSION",
        })
        # Use the API Key from Secret Manager here.
        model = ChatAnthropic(
            model_name=self.model_name,
            model_kwargs={"api_key": secret_version.payload.data.decode()},  # <- new
        )
        self.graph = create_react_agent(model, tools=self.tools)

    def query(self, **kwargs):
        return self.graph.invoke(**kwargs)

处理凭证

部署代理时,可能需要处理不同类型的凭证:

  1. 应用默认凭证 (ADC)(通常来自服务账号),
  2. OAuth(通常来自用户账号),
  3. 来自外部账号(工作负载身份联合)的凭证的身份提供方

应用默认凭据

import google.auth

credentials, project = google.auth.default(
    scopes=["https://www.googleapis.com/auth/cloud-platform"]
)

您可以在代码中按以下方式使用它:

from typing import Callable, Sequence

class CLASS_NAME:
    def __init__(
        self,
        model: str = "meta/llama3-405b-instruct-maas",
        tools: Sequence[Callable],
        location: str,
        project: str,
    ):
        self.model_name = model
        self.tools = tools
        self.project = project
        self.endpoint = f"https://{location}-aiplatform.googleapis.com"
        self.base_url = f'{self.endpoint}/v1beta1/projects/{project}/locations/{location}/endpoints/openapi'

    def query(self, **kwargs):
        import google.auth
        from langchain_openai import ChatOpenAI
        from langgraph.prebuilt import create_react_agent

        # Note: the credential lives for 1 hour by default.
        # After expiration, it must be refreshed.
        creds, _ = google.auth.default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
        creds.refresh(google.auth.transport.requests.Request())

        model = ChatOpenAI(
            model=self.model_name,
            base_url=self.base_url,
            api_key=creds.token,  # Use the token from the credentials here.
        )
        graph = create_react_agent(model, tools=self.tools)
        return graph.invoke(**kwargs)

如需了解详情,请参阅应用默认凭证的工作原理

OAuth

用户凭证通常使用 OAuth 2.0 获取。

如果您有访问令牌(例如来自 oauthlib),则可以创建 google.oauth2.credentials.Credentials 实例。此外,如果您获取了刷新令牌,还可以指定刷新令牌和令牌 URI,以允许自动刷新凭证:

credentials = google.oauth2.credentials.Credentials(
    token="ACCESS_TOKEN",
    refresh_token="REFRESH_TOKEN",  # Optional
    token_uri="TOKEN_URI",          # E.g. "https://oauth2.googleapis.com/token"
    client_id="CLIENT_ID",          # Optional
    client_secret="CLIENT_SECRET"   # Optional
)

在此处,TOKEN_URICLIENT_IDCLIENT_SECRET 基于创建 OAuth 客户端凭证中的内容。

如果您没有访问令牌,可以使用 google_auth_oauthlib.flow 执行 OAuth 2.0 授权授予流程,以获取相应的 google.oauth2.credentials.Credentials 实例:

from google.cloud import secretmanager
from google_auth_oauthlib.flow import InstalledAppFlow
import json

# Get the client config from Secret Manager here.
secret_manager_client = secretmanager.SecretManagerServiceClient()
secret_version = client.access_secret_version(request={
    "name": "projects/PROJECT_ID/secrets/SECRET_ID/versions/SECRET_VERSION",
})
client_config = json.loads(secret_version.payload.data.decode())

# Create flow instance to manage the OAuth 2.0 Authorization Grant Flow steps.
flow = InstalledAppFlow.from_client_config(
    client_config,
    scopes=['https://www.googleapis.com/auth/cloud-platform'],
    state="OAUTH_FLOW_STATE"  # from flow.authorization_url(...)
)

# You can get the credentials from the flow object.
credentials: google.oauth2.credentials.Credentials = flow.credentials

# After obtaining the credentials, you can then authorize API requests on behalf
# of the given user or service account. For example, to authorize API requests
# to vertexai services, you'll specify it in vertexai.init(credentials=)
import vertexai

vertexai.init(
    project="PROJECT_ID",
    location="LOCATION",
    credentials=credentials, # specify the credentials here
)

如需了解详情,请参阅 google_auth_oauthlib.flow 模块的文档

身份提供方

如果要使用邮箱/密码、电话号码、社交服务提供方(例如 Google、Facebook 或 GitHub)或自定义身份验证机制对用户进行身份验证,则可以使用 Identity PlatformFirebase 身份验证,也可以使用任何支持 OpenID Connect (OIDC) 的身份提供方。

如需了解详情,请参阅从 OIDC 身份提供方访问资源

处理错误

为确保以结构化 JSON 格式返回 API 错误,我们建议在代理代码中使用 try...except 块(可以抽象为修饰器)来实现错误处理。

虽然 Vertex AI Agent Engine 可以在内部处理各种状态代码,但 Python 缺乏一种标准化方式在所有异常类型中使用关联的 HTTP 状态代码表示错误。尝试在底层服务中将所有可能的 Python 异常映射到 HTTP 状态会很复杂,并且难以维护。

一种可伸缩性更强的方法是在代理方法中显式捕获相关异常,或使用 error_wrapper 等可重用的修饰器。然后,您可以关联适当的状态代码(例如,通过向自定义异常添加 codeerror 属性或专门处理标准异常),并将错误格式化为 JSON 字典以用于返回值。这只需要在代理方法本身中进行极少的代码更改,通常只需要添加修饰器。

以下示例展示了如何在代理中实现错误处理:

from functools import wraps
import json

def error_wrapper(func):
    @wraps(func)  # Preserve original function metadata
    def wrapper(*args, **kwargs):
        try:
            # Execute the original function with its arguments
            return func(*args, **kwargs)
        except Exception as err:
            error_code = getattr(err, 'code')
            error_message = getattr(err, 'error')

            # Construct the error response dictionary
            error_response = {
                "error": {
                    "code": error_code,
                    "message": f"'{func.__name__}': {error_message}"
                }
            }
            # Return the Python dictionary directly.
            return error_response

    return wrapper

# Example exception
class SessionNotFoundError(Exception):
    def __init__(self, session_id, message="Session not found"):
        self.code = 404
        self.error = f"{message}: {session_id}"
        super().__init__(self.error)

# Example Agent Class
class MyAgent:
    @error_wrapper
    def get_session(self, session_id: str):
        # Simulate the condition where the session isn't found
        raise SessionNotFoundError(session_id=session_id)


# Example Usage: Session Not Found
agent = MyAgent()
error_result = agent.get_session(session_id="nonexistent_session_123")
print(json.dumps(error_result, indent=2))

上面的代码会生成以下输出: json { "error": { "code": 404, "message": "Invocation error in 'get_session': Session not found: nonexistent_session_123" } }

后续步骤