Kerasを用いた複数入力モデル精度向上のためのTips

2018/09/25
このエントリーをはてなブックマークに追加

はじめに

カブクで機械学習エンジニアをしている大串正矢です。今回は複数入力モデルの精度向上のためのTipsについて書きます。

背景

複数入力のモデルでは単一入力のモデルとは異なり、下記のような問題点があります。
– データによってロスに対する貢献度が異なり、ロスが下がりやすいデータを優先して学習してしまう。
– 学習の収束性はデータによって異なり、全体の最適化を目指すと個々のデータで過学習が発生してしまう。

1の問題を解決する方法として重み付きヘテロジニアスラーニング、2の問題を解決するためTask Wise Early Stoppingという手法が存在するので今回はその手法を紹介します。

  • 注意点:両手法はマルチタスクラーニングに適用される手法のため、記事上ではタスクと表記しています。1入力を1タスクと置き換えて記述しています。

データ、前処理、モデル定義部分

データと前処理、モデル定義部分は前回の記事と同一なので省略します。

重み付きヘテロジニアスラーニング

重み付きヘテロジニアスラーニングは事前学習し、タスクごとのクラス重みを導出した後に重みを設定して再度学習する手法です。

重みの導出方法は下記になります。

  • タスク\(t\)の安定度を導出(検証データのロスの平均\(\mu_t\)、検証データのロスの標準偏差\(\sigma_t\) )
    \begin{align}
    N_t = \mu_t + 3 * \sigma_t
    \end{align}

  • 基準タスクの選択

    • \(N_t\)が最小のタスクを\(N_f\)とする
  • 基準タスクとの安定度の比から重みを算出
    \begin{align}
    w_t = N_f / N_t
    \end{align}

コードで処理するには下記のようになります。
コールバック関数を設定して学習の終了時に検証データのロスに対する平均と標準偏差を計算します。今回は下記のようなコールバック関数を作成しました。

from keras.callbacks import Callback
import numpy as np
from collections import OrderedDict


class HeteroGeniousCallbacks(Callback):

    def __init__(self,
                 variable_number: int = 2,
                 ):
        super(HeteroGeniousCallbacks, self).__init__()

        self.current_val_loss = {}
        self.class_weight = {}
        self.variable_number = variable_number

Epoch後の検証データのロスデータを貯める処理が下記になります。

    def on_epoch_end(self, epoch, logs=None):

        sort_logs = OrderedDict(
            sorted(logs.items(), key=lambda x: x[0]))
        for each_label, each_values in sort_logs.items():
            if 'val' in each_label and 'loss' in each_label:
                if each_label not in self.class_weight:
                    self.class_weight[each_label] = each_values
                else:
                    each_values_tmp = self.class_weight[each_label] 
                    each_value_list = np.vstack((each_values,
                                                 each_values_tmp))
                    self.class_weight[each_label] = each_value_list

学習終了後に基準タスクとの安定度の比から重みを導出します。

    def on_train_end(self, logs=None):
        self.class_weight = OrderedDict(
            sorted(self.class_weight.items(), key=lambda x: x[0]))

        val_stable_dict = {}
        for each_label, each_value_lsit in self.class_weight.items():
            val_stable_dict[each_label] = \
                np.average(each_value_lsit) + 3.0 * np.std(each_value_lsit)

        val_stable_dict = OrderedDict(
            sorted(val_stable_dict.items(), key=lambda x: x[1]))

        most_stable_value = [value for value in val_stable_dict.values()][0]

        index = 0
        tmp_class_weight = {}

        for each_label in self.class_weight.keys():
            tmp_class_weight[index] = \
                val_stable_dict[each_label] / most_stable_value
            if index >= self.variable_number - 1:
                break
            index += 1

        self.class_weight = tmp_class_weight
        print('hetero genious class weight {}'.format(self.class_weight))

実際に使用するには下記のように行います。
コールバック関数を指定して重みを計算

hetero_genious_callbacks = HeteroGeniousCallbacks()
model.fit(x, x, validation_split=0.1, epochs=1000, callbacks=[hetero_genious_callbacks])

コールバック関数で導出された重みを設定して再学習

model.fit(x, x, validation_split=0.1, epochs=1000, 
          class_weight=hetero_genious_callbacks.class_weight,)

Task wise Early Stopping

この手法はタスクごとに学習の収束性が異なるため、タスクごとに学習が収束したら学習を止める手法になります。手法はシンプルなのですがDefine and RUN型のフレームワークで実装している場合は学習中の柔軟な変更が難しく、少し強引に実現しました。

流れとしては下記のようになっています。

1 観測するタスクを指定
2 タスクのロスが収束してきたら学習を止める
3 タスクに関連するレイヤーをフリーズして再度、学習を開始
4 2と3を繰り返し、1で指定したタスク全てが収束すれば終了

注意点として学習中にレイヤーをフリーズしても反映されず、再度compile後に学習をする必要があります。

コードで実現するには下記のようになります。ほぼKerasで標準実装されているEarlyStopと同一ですがどのロスを見て止まったか把握する必要があります。
初期設定部分で観測していたどの値で止まったかを確認するための変数を定義します。self.stop_monitorがその部分に当たります。

class TaskWiseEarlyStopping(Callback):

    def __init__(self,
                 monitor='val_loss',
                 min_delta=0,
                 patience=0,
                 verbose=0,
                 mode='auto',
                 baseline=None,
                 restore_best_weights=False):
        super(TaskWiseEarlyStopping, self).__init__()

        self.monitor = monitor
        self.baseline = baseline
        self.patience = patience
        self.verbose = verbose
        self.min_delta = min_delta
        self.wait = 0
        self.stopped_epoch = 0
        self.restore_best_weights = restore_best_weights
        self.best_weights = None
        self.stop_monitor = None

epochごとに観測している値が向上していない場合に学習を止めます。self.stop_monitor = self.monitorでどの値で止まったをチェックします。

    def on_epoch_end(self, epoch, logs=None):
        current = self.get_monitor_value(logs)
        if current is None:
            return

        if self.monitor_op(current - self.min_delta, self.best):
            self.best = current
            self.wait = 0
            if self.restore_best_weights:
                self.best_weights = self.model.get_weights()
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True
                self.stop_monitor = self.monitor
                if self.restore_best_weights:
                    if self.verbose > 0:
                        print('Restoring model weights from the end of '
                              'the best epoch')
                    self.model.set_weights(self.best_weights)

上記までがTaskWiseEarlyStoppingのコールバック関数に関してです。ここからはこのコールバック関数に実際に値を設定して処理を行います。
モデルの出力値のどの部分をチェックするかは下記のコードで設定します。今回のモデルはdenseが出力になるのでdenseの末尾にlossを加え監視項目としています。

monitor_list = ['val_' + layer.name + '_loss' for layer in model.layers if 'dense' in layer.name]

callbacks_list = [TaskWiseEarlyStopping(monitor=monitor, patience=30, verbose=1) for monitor in monitor_list]

EarlyStopが適用された後で、どのレイヤーがストップしたかを確認します。設定したTaskWiseEarlyStopのコールバック群の中で学習を中止したものがあればそのstop_monitorNoneではなく観測していた出力になるのでその値を取得します。

stop_layer = ''
stop_loss = ''
for callback in callbacks_list:
    if callback.stop_monitor is not None:
        stop_layer = callback.stop_monitor.replace('_loss', '')
        stop_loss = callback.stop_monitor

学習を中止した出力に関係するレイヤーは次回の学習では学習しないようにレイヤーをフリーズ(学習を行わない)処理します。

stop_input_layer = ''
for layer in model.layers:
    if layer.name == stop_layer:
        layer.trainable = False
        stop_input_layer = layer.input.name.split('/')[2]

for layer in model.layers:
    if layer.name == stop_input_layer:
        layer.trainable = False

compile処理によってフリーズしたレイヤーを適用し、model.summary()で学習するパラメータが減っているか確認します。

model.compile(optimizer='adam', loss='mean_squared_error', metrics=['mse'])
model.summary()

下記が確認した結果です。学習しないパラメータ数が分かるNon-trainable paramsが0でないことが確認できます。

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_9 (InputLayer)            (None, 3, 1)         0                                            
__________________________________________________________________________________________________
input_10 (InputLayer)           (None, 3, 1)         0                                            
__________________________________________________________________________________________________
lstm_9 (LSTM)                   (None, 3, 120)       58560       input_9[0][0]                    
__________________________________________________________________________________________________
lstm_10 (LSTM)                  (None, 3, 150)       91200       input_10[0][0]                   
__________________________________________________________________________________________________
dense_9 (Dense)                 (None, 3, 1)         121         lstm_9[0][0]                     
__________________________________________________________________________________________________
dense_10 (Dense)                (None, 3, 1)         151         lstm_10[0][0]                    
==================================================================================================
Total params: 150,032
Trainable params: 91,351
Non-trainable params: 58,681
__________________________________________________________________________________________________

上記の結果だけでなくロスが変更されていないか確認も必要です。下記はdense_9_lossがフリーズされて学習が進まない結果です。dense_9_lossが変更されていないことが確認できます。

Epoch 1/800
150/150 [==============================] - 1s 3ms/step - loss: 3.8935e-04 - dense_9_loss: 1.6028e-07 - dense_10_loss: 3.8919e-04 - dense_9_mean_squared_error: 1.6028e-07 - dense_10_mean_squared_error: 3.8919e-04
Epoch 2/800
150/150 [==============================] - 0s 168us/step - loss: 1.9313e-04 - dense_9_loss: 1.6028e-07 - dense_10_loss: 1.9297e-04 - dense_9_mean_squared_error: 1.6028e-07 - dense_10_mean_squared_error: 1.9297e-04
Epoch 3/800
150/150 [==============================] - 0s 162us/step - loss: 1.3634e-04 - dense_9_loss: 1.6028e-07 - dense_10_loss: 1.3618e-04 - dense_9_mean_squared_error: 1.6028e-07 - dense_10_mean_squared_error: 1.3618e-04
Epoch 4/800
150/150 [==============================] - 0s 166us/step - loss: 8.3967e-05 - dense_9_loss: 1.6028e-07 - dense_10_loss: 8.3807e-05 - dense_9_mean_squared_error: 1.6028e-07 - dense_10_mean_squared_error: 8.3807e-05
Epoch 5/800
150/150 [==============================] - 0s 167us/step - loss: 5.1792e-05 - dense_9_loss: 1.6028e-07 - dense_10_loss: 5.1631e-05 - dense_9_mean_squared_error: 1.6028e-07 - dense_10_mean_squared_error: 5.1631e-05
Epoch 6/800
150/150 [==============================] - 0s 167us/step - loss: 3.2854e-05 - dense_9_loss: 1.6028e-07 - dense_10_loss: 3.2694e-05 - dense_9_mean_squared_error: 1.6028e-07 - dense_10_mean_squared_error: 3.2694e-05
:

結果

実験条件

  • 実行環境
    • OS: macOS Sierra
    • CPU: 2.9 GHz Intel Core i7
    • メモリー: 16 GB 2133 MHz LPDDR3
  • pythonバージョン
    • 3.6.0
  • ライブラリ

numpy==1.15.1
ipython==6.0.0
notebook==5.0.0
pandas==0.23.4
matplotlib==2.0.1
lxml==4.2.0
beautifulsoup4==4.6.0
scikit-learn==0.18.1
scipy==1.1.0
keras==2.2.2
tensorflow==1.8.0
  • モデル

  • LSTM

    • ノード数:120, 個別入力は気温変化が120, ガスの生産量は150
    • 他のパラメータはKearasで提供されているデフォルト値
  • Optimiser
    • adam
    • 他のパラメータはKearasで提供されているデフォルト値
  • epoch
    • 1000
  • EarlyStop(過学習を抑え、学習時間を短縮するために導入)
    • 検証データのロスが10epoch改善しない場合は強制的に止める処理

重み付きヘテロジニアスラーニング

重み付きヘテロジニアスラーニングの効果を確認します。

今回のデータでは事前学習の結果、入力0のデータが基準データとなり入力1のデータの学習が上手く進んでいないので導出された重みは下記になります。

{0: 1.0, 1: 9.7779}

これが適用された際の検証データにおけるロスがどのように変化しているかを見てみます。青が重み付きヘテロジニアスラーニングを適用する前で赤が適用後になります。
結果より適用後の方がロスが僅かに下がっており、入力0に対しては重要性を下げたことによる過学習の抑止、入力1に対しては重要度を上げたことによる学習の促進につながったと推測できます。

  • 入力0のケース

  • 入力1のケース

テストデータのRMSE

テストデータのRMSEを何も適用していないケースと重み付きヘテロジニアスラーニング、TaskWiseEarlyStoppingを適用した場合の値を比較します。
重み付きヘテロジニアスラーニングを適用した場合が最もRMSEが低く効果が高いことが伺えます。

各手法 RMSE
何もなし 0.0907
重み付きヘテロジニアスラーニング 0.0782
TaskWiseEarlyStopping 0.0829

全体のコード

最後に

弊社では標準的な実装例では解決できないような問題を解決できるスキルの高いエンジニアも絶賛採用中なので是非、弊社へ応募してください。

参考

https://www.slideshare.net/Takayosi/miru2018-tutorial-108675245/93

その他の記事

Other Articles

2022/06/03
拡張子に Web アプリを関連付ける File Handling API の使い方

2022/03/22
<selectmenu> タグできる子; <select> に代わるカスタマイズ可能なドロップダウンリスト

2022/03/02
Java 15 のテキストブロックを横目に C# 11 の生文字列リテラルを眺めて ECMAScript String dedent プロポーザルを想う

2021/10/13
Angularによる開発をできるだけ型安全にするためのKabukuでの取り組み

2021/09/30
さようなら、Node.js

2021/09/30
Union 型を含むオブジェクト型を代入するときに遭遇しうるTypeScript型チェックの制限について

2021/09/16
[ECMAScript] Pipe operator 論争まとめ – F# か Hack か両方か

2021/07/05
TypeScript v4.3 の機能を使って immutable ライブラリの型付けを頑張る

2021/06/25
Denoでwasmを動かすだけの話

2021/05/18
DOMMatrix: 2D / 3D 変形(アフィン変換)の行列を扱う DOM API

2021/03/29
GoのWASMがライブラリではなくアプリケーションであること

2021/03/26
Pythonプロジェクトの共通のひな形を作る

2021/03/25
インラインスタイルと Tailwind CSS と Tailwind CSS 入力補助ライブラリと Tailwind CSS in JS

2021/03/23
Serverless NEGを使ってApp Engineにカスタムドメインをワイルドカードマッピング

2021/01/07
esbuild の機能が足りないならプラグインを自作すればいいじゃない

2020/08/26
TypeScriptで関数の部分型を理解しよう

2020/06/16
[Web フロントエンド] esbuild が爆速すぎて webpack / Rollup にはもう戻れない

2020/03/19
[Web フロントエンド] Elm に心折れ Mint に癒しを求める

2020/02/28
さようなら、TypeScript enum

2020/02/14
受付のLooking Glassに加えたひと工夫

2020/01/28
カブクエンジニア開発合宿に行ってきました 2020冬

2020/01/30
Renovateで依存ライブラリをリノベーションしよう 〜 Bitbucket編 〜

2019/12/27
Cloud Tasks でも deferred ライブラリが使いたい

2019/12/25
*, ::before, ::after { flex: none; }

2019/12/21
Top-level awaitとDual Package Hazard

2019/12/20
Three.jsからWebGLまで行きて帰りし物語

2019/12/18
Three.jsに入門+手を検出してAR.jsと組み合わせてみた

2019/12/04
WebXR AR Paint その2

2019/11/06
GraphQLの入門書を翻訳しました

2019/09/20
Kabuku Connect 即時見積機能のバックエンド開発

2019/08/14
Maker Faire Tokyo 2019でARゲームを出展しました

2019/07/25
夏休みだョ!WebAssembly Proposal全員集合!!

2019/07/08
鵜呑みにしないで! —— 書籍『クリーンアーキテクチャ』所感 ≪null 篇≫

2019/07/03
W3C Workshop on Web Games参加レポート

2019/06/28
TypeScriptでObject.assign()に正しい型をつける

2019/06/25
カブクエンジニア開発合宿に行ってきました 2019夏

2019/06/21
Hola! KubeCon Europe 2019の参加レポート

2019/06/19
Clean Resume きれいな環境できれいな履歴書を作成する

2019/05/20
[Web フロントエンド] 状態更新ロジックをフレームワークから独立させる

2019/04/16
C++のenable_shared_from_thisを使う

2019/04/12
OpenAPI 3 ファーストな Web アプリケーション開発(Python で API 編)

2019/04/08
WebGLでレイマーチングを使ったCSGを実現する

2019/03/29
その1 Jetson TX2でk3s(枯山水)を動かしてみた

2019/04/02
『エンジニア採用最前線』に感化されて2週間でエンジニア主導の求人票更新フローを構築した話

2019/03/27
任意のブラウザ上でJestで書いたテストを実行する

2019/02/08
TypeScript で “radian” と “degree” を間違えないようにする

2019/02/05
Python3でGoogle Cloud ML Engineをローカルで動作する方法

2019/01/18
SIGGRAPH Asia 2018 参加レポート

2019/01/08
お正月だョ!ECMAScript Proposal全員集合!!

2019/01/08
カブクエンジニア開発合宿に行ってきました 2018秋

2018/12/25
OpenAPI 3 ファーストな Web アプリケーション開発(環境編)

2018/12/23
いまMLKitカスタムモデル(TF Lite)は使えるのか

2018/12/21
[IoT] Docker on JetsonでMQTTを使ってCloud IoT Coreと通信する

2018/12/11
TypeScriptで実現する型安全な多言語対応(Angularを例に)

2018/12/05
GASでCompute Engineの時間に応じた自動停止/起動ツールを作成する 〜GASで簡単に好きなGoogle APIを叩く方法〜

2018/12/02
single quotes な Black を vendoring して packaging

2018/11/14
3次元データに2次元データの深層学習の技術(Inception V3, ResNet)を適用

2018/11/04
Node Knockout 2018 に参戦しました

2018/10/24
SIGGRAPH 2018参加レポート-後編(VR/AR)

2018/10/11
Angular 4アプリケーションをAngular 6に移行する

2018/10/05
SIGGRAPH 2018参加レポート-特別編(VR@50)

2018/10/03
Three.jsでVRしたい

2018/10/02
SIGGRAPH 2018参加レポート-前編

2018/09/27
ズーム可能なSVGを実装する方法の解説

2018/09/21
競技プログラミングの勉強会を開催している話

2018/09/19
Ladder Netwoksによる半教師あり学習

2018/08/10
「Maker Faire Tokyo 2018」に出展しました

2018/08/02
Kerasを用いた複数時系列データを1つの深層学習モデルで学習させる方法

2018/07/26
Apollo GraphQLでWebサービスを開発してわかったこと

2018/07/19
【深層学習】時系列データに対する1次元畳み込み層の出力を可視化

2018/07/11
きたない requirements.txt から Pipenv への移行

2018/06/26
CSS Houdiniを味見する

2018/06/25
不確実性を考慮した時系列データ予測

2018/06/20
Google Colaboratory を自分のマシンで走らせる

2018/06/18
Go言語でWebAssembly

2018/06/15
カブクエンジニア開発合宿に行ってきました 2018春

2018/06/08
2018 年の tree shaking

2018/06/07
隠れマルコフモデル 入門

2018/05/30
DASKによる探索的データ分析(EDA)

2018/05/10
TensorFlowをソースからビルドする方法とその効果

2018/04/23
EGLとOpenGLを使用するコードのビルド方法〜libGLからlibOpenGLへ

2018/04/23
技術書典4にサークル参加してきました

2018/04/13
Python で Cura をバッチ実行するためには

2018/04/04
ARCoreで3Dプリント風エフェクトを実現する〜呪文による積層造形映像制作の舞台裏〜

2018/04/02
深層学習を用いた時系列データにおける異常検知

2018/04/01
音声ユーザーインターフェースを用いた新方式積層造形装置の提案

2018/03/31
Container builderでコンテナイメージをBuildしてSlackで結果を受け取る開発スタイルが捗る

2018/03/23
ngUpgrade を使って AngularJS から Angular に移行

2018/03/14
Three.jsのパフォーマンスTips

2018/02/14
C++17の新機能を試す〜その1「3次元版hypot」

2018/01/17
時系列データにおける異常検知

2018/01/11
異常検知の基礎

2018/01/09
three.ar.jsを使ったスマホAR入門

2017/12/17
Python OpenAPIライブラリ bravado-core の発展的な使い方

2017/12/15
WebAssembly(wat)を手書きする

2017/12/14
AngularJS を Angular に移行: ng-annotate 相当の機能を TypeScrpt ファイルに適用

2017/12/08
Android Thingsで4足ロボットを作る ~ Android ThingsとPCA9685でサーボ制御)

2017/12/06
Raspberry PIとDialogflow & Google Cloud Platformを利用した、3Dプリンターボット(仮)の開発 (概要編)

2017/11/20
カブクエンジニア開発合宿に行ってきました 2017秋

2017/10/19
Android Thingsを使って3Dプリント戦車を作ろう ① ハードウェア準備編

2017/10/13
第2回 魁!! GPUクラスタ on GKE ~PodからGPUを使う編~

2017/10/05
第1回 魁!! GPUクラスタ on GKE ~GPUクラスタ構築編~

2017/09/13
「Maker Faire Tokyo 2017」に出展しました。

2017/09/11
PyConJP2017に参加しました

2017/09/08
bravado-coreによるOpenAPIを利用したPythonアプリケーション開発

2017/08/23
OpenAPIのご紹介

2017/08/18
EuroPython2017で2名登壇しました。

2017/07/26
3DプリンターでLチカ

2017/07/03
Three.js r86で何が変わったのか

2017/06/21
3次元データへの深層学習の適用

2017/06/01
カブクエンジニア開発合宿に行ってきました 2017春

2017/05/08
Three.js r85で何が変わったのか

2017/04/10
GCPのGPUインスタンスでレンダリングを高速化

2017/02/07
Three.js r84で何が変わったのか

2017/01/27
Google App EngineのFlexible EnvironmentにTmpfsを導入する

2016/12/21
Three.js r83で何が変わったのか

2016/12/02
Three.jsでのクリッピング平面の利用

2016/11/08
Three.js r82で何が変わったのか

2016/12/17
SIGGRAPH 2016 レポート

2016/11/02
カブクエンジニア開発合宿に行ってきました 2016秋

2016/10/28
PyConJP2016 行きました

2016/10/17
EuroPython2016で登壇しました

2016/10/13
Angular 2.0.0ファイナルへのアップグレード

2016/10/04
Three.js r81で何が変わったのか

2016/09/14
カブクのエンジニアインターンシッププログラムについての詩

2016/09/05
カブクのエンジニアインターンとして3ヶ月でやった事 〜高橋知成の場合〜

2016/08/30
Three.js r80で何が変わったのか

2016/07/15
Three.js r79で何が変わったのか

2016/06/02
Vulkanを試してみた

2016/05/20
MakerGoの作り方

2016/05/08
TensorFlow on DockerでGPUを使えるようにする方法

2016/04/27
Blenderの3DデータをMinecraftに送りこむ

2016/04/20
Tensorflowを使ったDeep LearningにおけるGPU性能調査

→
←

関連職種

Recruit

→
←

お客様のご要望に「Kabuku」はお応えいたします。
ぜひお気軽にご相談ください。

お電話でも受け付けております
03-6380-2750
営業時間:09:30~18:00
※土日祝は除く