ハイパーパラメータ チューニングでモデル性能を強化する

このチュートリアルでは、BigQuery ML でハイパーパラメータ チューニングを使用して ML モデルをチューニングし、パフォーマンスを改善する方法について説明します。

ハイパーパラメータ チューニングを行うには、CREATE MODEL ステートメントの NUM_TRIALS オプションを、他のモデル固有のオプションと組み合わせて指定します。これらのオプションを設定すると、BigQuery ML は、それぞれパラメータがわずかに異なるモデルの複数のバージョン(トライアル)をトレーニングし、パフォーマンスが最も高いトライアルを返します。

このチュートリアルでは、2018 年のニューヨーク市のタクシー乗車に関する情報が含まれている一般公開の tlc_yellow_trips_2018 サンプル テーブルを使用します。

データセットを作成する

ML モデルを保存する BigQuery データセットを作成します。

コンソール

  1. Google Cloud コンソールで、[BigQuery] ページに移動します。

    [BigQuery] ページに移動

  2. [エクスプローラ] ペインで、プロジェクト名をクリックします。

  3. [アクションを表示] > [データセットを作成] をクリックします。

  4. [データセットを作成する] ページで、次の操作を行います。

    • [データセット ID] に「bqml_tutorial」と入力します。

    • [ロケーション タイプ] で [マルチリージョン] を選択してから、[US(米国の複数のリージョン)] を選択します。

    • 残りのデフォルトの設定は変更せず、[データセットを作成] をクリックします。

bq

新しいデータセットを作成するには、--location フラグを指定した bq mk コマンドを使用します。使用可能なパラメータの一覧については、bq mk --dataset コマンドのリファレンスをご覧ください。

  1. データの場所が US に設定され、BigQuery ML tutorial dataset という説明の付いた、bqml_tutorial という名前のデータセットを作成します。

    bq --location=US mk -d \
     --description "BigQuery ML tutorial dataset." \
     bqml_tutorial

    このコマンドでは、--dataset フラグの代わりに -d ショートカットを使用しています。-d--dataset を省略した場合、このコマンドはデフォルトでデータセットを作成します。

  2. データセットが作成されたことを確認します。

    bq ls

API

定義済みのデータセット リソースを使用して datasets.insert メソッドを呼び出します。

{
  "datasetReference": {
     "datasetId": "bqml_tutorial"
  }
}

BigQuery DataFrames

このサンプルを試す前に、BigQuery DataFrames を使用した BigQuery クイックスタートの手順に沿って BigQuery DataFrames を設定してください。詳細については、BigQuery DataFrames のリファレンス ドキュメントをご覧ください。

BigQuery に対する認証を行うには、アプリケーションのデフォルト認証情報を設定します。詳細については、ローカル開発環境の ADC の設定をご覧ください。

import google.cloud.bigquery

bqclient = google.cloud.bigquery.Client()
bqclient.create_dataset("bqml_tutorial", exists_ok=True)

トレーニング データのテーブルを作成する

tlc_yellow_trips_2018 テーブルデータのサブセットに基づいて、トレーニング データのテーブルを作成します。

次の手順でテーブルを作成します。

  1. Google Cloud コンソールで、[BigQuery] ページに移動します。

    [BigQuery] に移動

  2. クエリエディタに次のクエリを貼り付け、[実行] をクリックします。

    CREATE OR REPLACE TABLE `bqml_tutorial.taxi_tip_input`
    AS
    SELECT * EXCEPT (tip_amount), tip_amount AS label
    FROM
      `bigquery-public-data.new_york_taxi_trips.tlc_yellow_trips_2018`
    WHERE
      tip_amount IS NOT NULL
    LIMIT 100000;

ベースラインの線形回帰モデルを作成する

ハイパーパラメータのチューニングなしで線形回帰モデルを作成し、taxi_tip_input テーブルデータでトレーニングします。

次の手順でモデルを作成します。

  1. Google Cloud コンソールで、[BigQuery] ページに移動します。

    [BigQuery] に移動

  2. クエリエディタに次のクエリを貼り付け、[実行] をクリックします。

    CREATE OR REPLACE MODEL `bqml_tutorial.baseline_taxi_tip_model`
      OPTIONS (
        MODEL_TYPE = 'LINEAR_REG'
      )
    AS
    SELECT
      *
    FROM
      `bqml_tutorial.taxi_tip_input`;

    クエリの完了には、約 2 分かかります。

ベースライン モデルを評価する

ML.EVALUATE 関数を使用してモデルのパフォーマンスを評価します。 ML.EVALUATE 関数は、モデルから返された予測コンテンツ評価を、モデル トレーニング中に計算された評価指標と照らし合わせて評価します。

次の手順でモデルを評価します。

  1. Google Cloud コンソールで、[BigQuery] ページに移動します。

    [BigQuery] に移動

  2. クエリエディタに次のクエリを貼り付け、[実行] をクリックします。

    SELECT *
    FROM
      ML.EVALUATE(MODEL `bqml_tutorial.baseline_taxi_tip_model`);

    結果は次のようになります。

    +---------------------+--------------------+------------------------+-----------------------+---------------------+---------------------+
    | mean_absolute_error | mean_squared_error | mean_squared_log_error | median_absolute_error |      r2_score       | explained_variance  |
    +---------------------+--------------------+------------------------+-----------------------+---------------------+---------------------+
    |  2.5853895559690323 | 23760.416358496139 |   0.017392406523370374 | 0.0044248227819481123 | -1934.5450533482465 | -1934.3513857946277 |
    +---------------------+--------------------+------------------------+-----------------------+---------------------+---------------------+
    

ベースライン モデルの r2_score 値は負の値で、データに適合していないことを示します。R2 スコアが 1 に近いほど、モデルの適合性は高くなります。

ハイパーパラメータ チューニングを使用して線形回帰モデルを作成する

ハイパーパラメータ チューニングを使用して線形回帰モデルを作成し、taxi_tip_input テーブルデータでトレーニングします。

CREATE MODEL ステートメントでは、次のハイパーパラメータ チューニング オプションを使用します。

  • NUM_TRIALS オプション: トライアル数を 20 に設定します。
  • MAX_PARALLEL_TRIALS オプション: 各トレーニング ジョブで 2 つのトライアルを実行し、合計 10 個のジョブと 20 個のトライアルを実行します。これにより、必要なトレーニング時間が短縮されます。ただし、2 つのトライアルを同時に実行する場合、互いのトレーニング結果による相乗効果は得られません。
  • L1_REG オプション: さまざまなトライアルで異なる L1 正則化値を試すことができます。L1 正則化では、モデルから無関係な特徴量が削除されるため、過学習を防ぐことができます。

モデルでサポートされている他のハイパーパラメータ チューニング オプションは、次のようにデフォルト値を使用します。

  • L1_REG: 0
  • HPARAM_TUNING_ALGORITHM: 'VIZIER_DEFAULT'
  • HPARAM_TUNING_OBJECTIVES: ['R2_SCORE']

次の手順でモデルを作成します。

  1. Google Cloud コンソールで、[BigQuery] ページに移動します。

    [BigQuery] に移動

  2. クエリエディタに次のクエリを貼り付け、[実行] をクリックします。

    CREATE OR REPLACE MODEL `bqml_tutorial.hp_taxi_tip_model`
      OPTIONS (
        MODEL_TYPE = 'LINEAR_REG',
        NUM_TRIALS = 20,
        MAX_PARALLEL_TRIALS = 2,
        L1_REG = HPARAM_RANGE(0, 5))
    AS
    SELECT
      *
    FROM
      `bqml_tutorial.taxi_tip_input`;

    このクエリが完了するまでに約 20 分かかります。

トレーニング トライアルに関する情報を取得する

ML.TRIAL_INFO 関数を使用して、ハイパーパラメータ値、目標、ステータスなど、すべてのトライアルに関する情報を取得します。この関数は、この情報に基づいて、パフォーマンスが最も優れているトライアルに関する情報も返します。

トライアル情報を取得する手順は次のとおりです。

  1. Google Cloud コンソールで、[BigQuery] ページに移動します。

    [BigQuery] に移動

  2. クエリエディタに次のクエリを貼り付け、[実行] をクリックします。

    SELECT *
    FROM
      ML.TRIAL_INFO(MODEL `bqml_tutorial.hp_taxi_tip_model`)
    ORDER BY is_optimal DESC;

    結果は次のようになります。

    +----------+-------------------------------------+-----------------------------------+--------------------+--------------------+-----------+---------------+------------+
    | trial_id |           hyperparameters           | hparam_tuning_evaluation_metrics  |   training_loss    |     eval_loss      |  status   | error_message | is_optimal |
    +----------+-------------------------------------+-----------------------------------+--------------------+--------------------+-----------+---------------+------------+
    |        7 |      {"l1_reg":"4.999999999999985"} |  {"r2_score":"0.653653627638174"} | 4.4677841296238165 |  4.478469742512195 | SUCCEEDED | NULL          |       true |
    |        2 |  {"l1_reg":"2.402163664510254E-11"} | {"r2_score":"0.6532493667964732"} |  4.457692508421795 |  4.483697081650438 | SUCCEEDED | NULL          |      false |
    |        3 |  {"l1_reg":"1.2929452948742316E-7"} |  {"r2_score":"0.653249366811995"} |   4.45769250849513 |  4.483697081449748 | SUCCEEDED | NULL          |      false |
    |        4 |  {"l1_reg":"2.5787102060628228E-5"} | {"r2_score":"0.6532493698925899"} |  4.457692523040582 |  4.483697041615808 | SUCCEEDED | NULL          |      false |
    |      ... |                             ...     |                           ...     |              ...   |             ...    |       ... |          ...  |        ... |
    +----------+-------------------------------------+-----------------------------------+--------------------+--------------------+-----------+---------------+------------+
    

    is_optimal 列の値は、トライアル 7 がチューニングによって返された最適なモデルであることを示しています。

チューニング済みモデルの試行を評価する

ML.EVALUATE 関数を使用してトライアルのパフォーマンスを評価します。ML.EVALUATE 関数は、モデルから返された予測コンテンツ評価を、すべてのトライアルのトレーニング中に計算された評価指標と比較して評価します。

トライアルを評価する手順は次のとおりです。

  1. Google Cloud コンソールで、[BigQuery] ページに移動します。

    [BigQuery] に移動

  2. クエリエディタに次のクエリを貼り付け、[実行] をクリックします。

    SELECT *
    FROM
      ML.EVALUATE(MODEL `bqml_tutorial.hp_taxi_tip_model`)
    ORDER BY r2_score DESC;

    結果は次のようになります。

    +----------+---------------------+--------------------+------------------------+-----------------------+--------------------+--------------------+
    | trial_id | mean_absolute_error | mean_squared_error | mean_squared_log_error | median_absolute_error |      r2_score      | explained_variance |
    +----------+---------------------+--------------------+------------------------+-----------------------+--------------------+--------------------+
    |        7 |   1.151814398002232 |  4.109811493266523 |     0.4918733252641176 |    0.5736103414025084 | 0.6652110305659145 | 0.6652144696114834 |
    |       19 |  1.1518143358927102 |  4.109811921460791 |     0.4918672150119582 |    0.5736106106914161 | 0.6652109956848206 | 0.6652144346901685 |
    |        8 |   1.152747850702547 |  4.123625876152422 |     0.4897808307399327 |    0.5731702310239184 | 0.6640856984144734 |  0.664088410199906 |
    |        5 |   1.152895108945439 |  4.125775524878872 |    0.48939088205957937 |    0.5723300569616766 | 0.6639105860807425 | 0.6639132416838652 |
    |      ... |                ...  |                ... |                    ... |                   ... |                ... |                ... |
    +----------+---------------------+--------------------+------------------------+-----------------------+--------------------+--------------------+
    

    最適なモデル(トライアル 7)の r2_score 値は 0.66521103056591446 で、ベースライン モデルよりも大幅に改善されています。

特定のトライアルを評価するには、ML.EVALUATE 関数で TRIAL_ID 引数を指定します。

ML.TRIAL_INFO 目標と ML.EVALUATE 評価指標の違いについて詳しくは、モデル提供関数をご覧ください。

チューニング済みモデルを使用してタクシーのチップ代を予測する

チューニングによって返された最適なモデルを使用して、さまざまなタクシーの乗車区間のチップ代を予測します。TRIAL_ID 引数を指定して別のトライアルを選択しない限り、最適なモデルが ML.PREDICT 関数によって自動的に使用されます。予測は predicted_label 列に返されます。

予測を取得する手順は次のとおりです。

  1. Google Cloud コンソールで、[BigQuery] ページに移動します。

    [BigQuery] に移動

  2. クエリエディタに次のクエリを貼り付け、[実行] をクリックします。

    SELECT *
    FROM
      ML.PREDICT(
        MODEL `bqml_tutorial.hp_taxi_tip_model`,
        (
          SELECT
            *
          FROM
            `bqml_tutorial.taxi_tip_input`
          LIMIT 5
        ));

    結果は次のようになります。

    +----------+--------------------+-----------+---------------------+---------------------+-----------------+---------------+-----------+--------------------+--------------+-------------+-------+---------+--------------+---------------+--------------+--------------------+---------------------+----------------+-----------------+-------+
    | trial_id |  predicted_label   | vendor_id |   pickup_datetime   |  dropoff_datetime   | passenger_count | trip_distance | rate_code | store_and_fwd_flag | payment_type | fare_amount | extra | mta_tax | tolls_amount | imp_surcharge | total_amount | pickup_location_id | dropoff_location_id | data_file_year | data_file_month | label |
    +----------+--------------------+-----------+---------------------+---------------------+-----------------+---------------+-----------+--------------------+--------------+-------------+-------+---------+--------------+---------------+--------------+--------------------+---------------------+----------------+-----------------+-------+
    |        7 |  1.343367839584448 | 2         | 2018-01-15 18:55:15 | 2018-01-15 18:56:18 |               1 |             0 | 1         | N                  | 1            |           0 |     0 |       0 |            0 |             0 |            0 | 193                | 193                 |           2018 |               1 |     0 |
    |        7 | -1.176072791783461 | 1         | 2018-01-08 10:26:24 | 2018-01-08 10:26:37 |               1 |             0 | 5         | N                  | 3            |        0.01 |     0 |       0 |            0 |           0.3 |         0.31 | 158                | 158                 |           2018 |               1 |     0 |
    |        7 |  3.839580104168765 | 1         | 2018-01-22 10:58:02 | 2018-01-22 12:01:11 |               1 |          16.1 | 1         | N                  | 1            |        54.5 |     0 |     0.5 |            0 |           0.3 |         55.3 | 140                | 91                  |           2018 |               1 |     0 |
    |        7 |  4.677393985230036 | 1         | 2018-01-16 10:14:35 | 2018-01-16 11:07:28 |               1 |            18 | 1         | N                  | 2            |        54.5 |     0 |     0.5 |            0 |           0.3 |         55.3 | 138                | 67                  |           2018 |               1 |     0 |
    |        7 |  7.938988937253062 | 2         | 2018-01-16 07:05:15 | 2018-01-16 08:06:31 |               1 |          17.8 | 1         | N                  | 1            |        54.5 |     0 |     0.5 |            0 |           0.3 |        66.36 | 132                | 255                 |           2018 |               1 | 11.06 |
    +----------+--------------------+-----------+---------------------+---------------------+-----------------+---------------+-----------+--------------------+--------------+-------------+-------+---------+--------------+---------------+--------------+--------------------+---------------------+----------------+-----------------+-------+