使用超参数调节来提高模型性能

本教程介绍了如何在 BigQuery ML 中使用超参数调优来调优机器学习模型并提高其性能。

您可以通过指定 CREATE MODEL 语句的 NUM_TRIALS 选项并结合使用其他特定于模型的选项,来执行超参数调优。设置这些选项后,BigQuery ML 会训练模型的多个版本(或试验),每个版本或试验的参数略有不同,并返回性能最佳的试验。

本教程使用公开的 tlc_yellow_trips_2018 示例表,其中包含 2018 年纽约市出租车行程的相关信息。

创建数据集

创建 BigQuery 数据集以存储机器学习模型。

控制台

  1. 在 Google Cloud 控制台中,前往 BigQuery 页面。

    转到 BigQuery 页面

  2. 探索器窗格中,点击您的项目名称。

  3. 点击 查看操作 > 创建数据集

  4. 创建数据集 页面上,执行以下操作:

    • 数据集 ID 部分,输入 bqml_tutorial

    • 位置类型部分,选择多区域,然后选择 US (multiple regions in United States)(美国[美国的多个区域])。

    • 保持其余默认设置不变,然后点击创建数据集

bq

如需创建新数据集,请使用带有 --location 标志的 bq mk 命令。 如需查看完整的潜在参数列表,请参阅 bq mk --dataset 命令参考文档。

  1. 创建一个名为 bqml_tutorial 的数据集,并将数据位置设置为 US,说明为 BigQuery ML tutorial dataset

    bq --location=US mk -d \
     --description "BigQuery ML tutorial dataset." \
     bqml_tutorial

    该命令使用的不是 --dataset 标志,而是 -d 快捷方式。如果省略 -d--dataset,该命令会默认创建一个数据集。

  2. 确认已创建数据集:

    bq ls

API

使用已定义的数据集资源调用 datasets.insert 方法。

{
  "datasetReference": {
     "datasetId": "bqml_tutorial"
  }
}

BigQuery DataFrame

在尝试此示例之前,请按照《BigQuery 快速入门:使用 BigQuery DataFrames》中的 BigQuery DataFrames 设置说明进行操作。如需了解详情,请参阅 BigQuery DataFrames 参考文档

如需向 BigQuery 进行身份验证,请设置应用默认凭证。如需了解详情,请参阅为本地开发环境设置 ADC

import google.cloud.bigquery

bqclient = google.cloud.bigquery.Client()
bqclient.create_dataset("bqml_tutorial", exists_ok=True)

创建训练数据表

根据 tlc_yellow_trips_2018 表数据的一部分创建训练数据表。

请按照以下步骤创建表:

  1. 在 Google Cloud 控制台中,前往 BigQuery 页面。

    转到 BigQuery

  2. 在查询编辑器中,粘贴以下查询,然后点击运行

    CREATE OR REPLACE TABLE `bqml_tutorial.taxi_tip_input`
    AS
    SELECT * EXCEPT (tip_amount), tip_amount AS label
    FROM
      `bigquery-public-data.new_york_taxi_trips.tlc_yellow_trips_2018`
    WHERE
      tip_amount IS NOT NULL
    LIMIT 100000;

创建基准线性回归模型

创建一个不含超参数调优的线性回归模型,并根据 taxi_tip_input 表数据对其进行训练。

请按照以下步骤创建模型:

  1. 在 Google Cloud 控制台中,前往 BigQuery 页面。

    转到 BigQuery

  2. 在查询编辑器中,粘贴以下查询,然后点击运行

    CREATE OR REPLACE MODEL `bqml_tutorial.baseline_taxi_tip_model`
      OPTIONS (
        MODEL_TYPE = 'LINEAR_REG'
      )
    AS
    SELECT
      *
    FROM
      `bqml_tutorial.taxi_tip_input`;

    查询大约需要 2 分钟才能完成。

评估基准模型

使用 ML.EVALUATE 函数评估模型的性能。ML.EVALUATE 函数根据模型训练期间计算出的评估指标来评估模型返回的预测内容分级。

请按照以下步骤评估模型:

  1. 在 Google Cloud 控制台中,前往 BigQuery 页面。

    转到 BigQuery

  2. 在查询编辑器中,粘贴以下查询,然后点击运行

    SELECT *
    FROM
      ML.EVALUATE(MODEL `bqml_tutorial.baseline_taxi_tip_model`);

    结果类似于以下内容:

    +---------------------+--------------------+------------------------+-----------------------+---------------------+---------------------+
    | mean_absolute_error | mean_squared_error | mean_squared_log_error | median_absolute_error |      r2_score       | explained_variance  |
    +---------------------+--------------------+------------------------+-----------------------+---------------------+---------------------+
    |  2.5853895559690323 | 23760.416358496139 |   0.017392406523370374 | 0.0044248227819481123 | -1934.5450533482465 | -1934.3513857946277 |
    +---------------------+--------------------+------------------------+-----------------------+---------------------+---------------------+
    

基准模型的 r2_score 值为负,表示模型与数据的拟合度较差;R2 得分越接近 1,模型拟合度就越好。

创建具有超参数调优的线性回归模型

创建一个具有超参数调优的线性回归模型,并根据 taxi_tip_input 表数据对其进行训练。

您可以在 CREATE MODEL 语句中使用以下超参数调优选项:

  • NUM_TRIALS 选项,可将试验次数设置为 20。
  • MAX_PARALLEL_TRIALS 选项,可在每个训练作业中运行 2 次试验,总共运行 10 个作业和 20 次试验。这样可缩短所需的训练时间。不过,两个并发试验不会从彼此的训练结果中受益。
  • L1_REG 选项,用于在不同的试验中尝试不同的 L1 正则化值。 L1 正则化会从模型中移除不相关的特征,这有助于防止出现过拟合

模型支持的其他超参数调优选项使用其默认值,如下所示:

  • L1_REG0
  • HPARAM_TUNING_ALGORITHM'VIZIER_DEFAULT'
  • HPARAM_TUNING_OBJECTIVES['R2_SCORE']

请按照以下步骤创建模型:

  1. 在 Google Cloud 控制台中,前往 BigQuery 页面。

    转到 BigQuery

  2. 在查询编辑器中,粘贴以下查询,然后点击运行

    CREATE OR REPLACE MODEL `bqml_tutorial.hp_taxi_tip_model`
      OPTIONS (
        MODEL_TYPE = 'LINEAR_REG',
        NUM_TRIALS = 20,
        MAX_PARALLEL_TRIALS = 2,
        L1_REG = HPARAM_RANGE(0, 5))
    AS
    SELECT
      *
    FROM
      `bqml_tutorial.taxi_tip_input`;

    该查询大约需要 20 分钟才能完成。

获取有关训练试验的信息

使用 ML.TRIAL_INFO 函数可获取有关所有试验的信息,包括其超参数值、目标和状态。此函数还会根据此信息返回有关哪个试验性能最佳的信息。

请按照以下步骤获取试验信息:

  1. 在 Google Cloud 控制台中,前往 BigQuery 页面。

    转到 BigQuery

  2. 在查询编辑器中,粘贴以下查询,然后点击运行

    SELECT *
    FROM
      ML.TRIAL_INFO(MODEL `bqml_tutorial.hp_taxi_tip_model`)
    ORDER BY is_optimal DESC;

    结果类似于以下内容:

    +----------+-------------------------------------+-----------------------------------+--------------------+--------------------+-----------+---------------+------------+
    | trial_id |           hyperparameters           | hparam_tuning_evaluation_metrics  |   training_loss    |     eval_loss      |  status   | error_message | is_optimal |
    +----------+-------------------------------------+-----------------------------------+--------------------+--------------------+-----------+---------------+------------+
    |        7 |      {"l1_reg":"4.999999999999985"} |  {"r2_score":"0.653653627638174"} | 4.4677841296238165 |  4.478469742512195 | SUCCEEDED | NULL          |       true |
    |        2 |  {"l1_reg":"2.402163664510254E-11"} | {"r2_score":"0.6532493667964732"} |  4.457692508421795 |  4.483697081650438 | SUCCEEDED | NULL          |      false |
    |        3 |  {"l1_reg":"1.2929452948742316E-7"} |  {"r2_score":"0.653249366811995"} |   4.45769250849513 |  4.483697081449748 | SUCCEEDED | NULL          |      false |
    |        4 |  {"l1_reg":"2.5787102060628228E-5"} | {"r2_score":"0.6532493698925899"} |  4.457692523040582 |  4.483697041615808 | SUCCEEDED | NULL          |      false |
    |      ... |                             ...     |                           ...     |              ...   |             ...    |       ... |          ...  |        ... |
    +----------+-------------------------------------+-----------------------------------+--------------------+--------------------+-----------+---------------+------------+
    

    is_optimal 列值表示,试验 7 是调优返回的最佳模型。

评估调优的模型试验

使用 ML.EVALUATE 函数评估试验的性能。ML.EVALUATE 函数根据所有试验训练期间计算的评估指标,来评估模型返回的预测内容分级。

请按照以下步骤评估试验:

  1. 在 Google Cloud 控制台中,前往 BigQuery 页面。

    转到 BigQuery

  2. 在查询编辑器中,粘贴以下查询,然后点击运行

    SELECT *
    FROM
      ML.EVALUATE(MODEL `bqml_tutorial.hp_taxi_tip_model`)
    ORDER BY r2_score DESC;

    结果类似于以下内容:

    +----------+---------------------+--------------------+------------------------+-----------------------+--------------------+--------------------+
    | trial_id | mean_absolute_error | mean_squared_error | mean_squared_log_error | median_absolute_error |      r2_score      | explained_variance |
    +----------+---------------------+--------------------+------------------------+-----------------------+--------------------+--------------------+
    |        7 |   1.151814398002232 |  4.109811493266523 |     0.4918733252641176 |    0.5736103414025084 | 0.6652110305659145 | 0.6652144696114834 |
    |       19 |  1.1518143358927102 |  4.109811921460791 |     0.4918672150119582 |    0.5736106106914161 | 0.6652109956848206 | 0.6652144346901685 |
    |        8 |   1.152747850702547 |  4.123625876152422 |     0.4897808307399327 |    0.5731702310239184 | 0.6640856984144734 |  0.664088410199906 |
    |        5 |   1.152895108945439 |  4.125775524878872 |    0.48939088205957937 |    0.5723300569616766 | 0.6639105860807425 | 0.6639132416838652 |
    |      ... |                ...  |                ... |                    ... |                   ... |                ... |                ... |
    +----------+---------------------+--------------------+------------------------+-----------------------+--------------------+--------------------+
    

    最佳模型(即试验 7)的 r2_score 值为 0.66521103056591446,这表明与基准模型相比,效果有了显著改进。

您可以通过在 ML.EVALUATE 函数中指定 TRIAL_ID 参数来评估特定试验。

如需详细了解 ML.TRIAL_INFO 目标与 ML.EVALUATE 评估指标之间的区别,请参阅模型部署函数

使用调优的模型预测出租车小费

使用调优返回的最佳模型来预测不同出租车行程的小费。ML.PREDICT 函数会自动使用最佳模型,除非您通过指定 TRIAL_ID 参数来选择其他试验。预测结果会返回在 predicted_label 列中。

请按照以下步骤获取预测结果:

  1. 在 Google Cloud 控制台中,前往 BigQuery 页面。

    转到 BigQuery

  2. 在查询编辑器中,粘贴以下查询,然后点击运行

    SELECT *
    FROM
      ML.PREDICT(
        MODEL `bqml_tutorial.hp_taxi_tip_model`,
        (
          SELECT
            *
          FROM
            `bqml_tutorial.taxi_tip_input`
          LIMIT 5
        ));

    结果类似于以下内容:

    +----------+--------------------+-----------+---------------------+---------------------+-----------------+---------------+-----------+--------------------+--------------+-------------+-------+---------+--------------+---------------+--------------+--------------------+---------------------+----------------+-----------------+-------+
    | trial_id |  predicted_label   | vendor_id |   pickup_datetime   |  dropoff_datetime   | passenger_count | trip_distance | rate_code | store_and_fwd_flag | payment_type | fare_amount | extra | mta_tax | tolls_amount | imp_surcharge | total_amount | pickup_location_id | dropoff_location_id | data_file_year | data_file_month | label |
    +----------+--------------------+-----------+---------------------+---------------------+-----------------+---------------+-----------+--------------------+--------------+-------------+-------+---------+--------------+---------------+--------------+--------------------+---------------------+----------------+-----------------+-------+
    |        7 |  1.343367839584448 | 2         | 2018-01-15 18:55:15 | 2018-01-15 18:56:18 |               1 |             0 | 1         | N                  | 1            |           0 |     0 |       0 |            0 |             0 |            0 | 193                | 193                 |           2018 |               1 |     0 |
    |        7 | -1.176072791783461 | 1         | 2018-01-08 10:26:24 | 2018-01-08 10:26:37 |               1 |             0 | 5         | N                  | 3            |        0.01 |     0 |       0 |            0 |           0.3 |         0.31 | 158                | 158                 |           2018 |               1 |     0 |
    |        7 |  3.839580104168765 | 1         | 2018-01-22 10:58:02 | 2018-01-22 12:01:11 |               1 |          16.1 | 1         | N                  | 1            |        54.5 |     0 |     0.5 |            0 |           0.3 |         55.3 | 140                | 91                  |           2018 |               1 |     0 |
    |        7 |  4.677393985230036 | 1         | 2018-01-16 10:14:35 | 2018-01-16 11:07:28 |               1 |            18 | 1         | N                  | 2            |        54.5 |     0 |     0.5 |            0 |           0.3 |         55.3 | 138                | 67                  |           2018 |               1 |     0 |
    |        7 |  7.938988937253062 | 2         | 2018-01-16 07:05:15 | 2018-01-16 08:06:31 |               1 |          17.8 | 1         | N                  | 1            |        54.5 |     0 |     0.5 |            0 |           0.3 |        66.36 | 132                | 255                 |           2018 |               1 | 11.06 |
    +----------+--------------------+-----------+---------------------+---------------------+-----------------+---------------+-----------+--------------------+--------------+-------------+-------+---------+--------------+---------------+--------------+--------------------+---------------------+----------------+-----------------+-------+