オッサンはDesktopが好き

自作PCや機械学習、自転車のことを脈絡無く書きます

ディープ・ラーニング: 複数種類の異常検知におけるChannelの効能がスゴい

 こんにちは.changです.

 以前に6 ChannelのU-Netを構成して複数種類の異常を覚えさせたとき,「1 Channelで,傷の種類ごとに別のネットワークにした方が楽そうだ」と書きました*1. 今回はこれを掘り下げて実証してみようと思います.

 結論を先に言っておくと,予想とは裏腹に6 Channelの方が良いことが判りました.

f:id:changlikesdesktop:20200719110403p:plain:w600

0. Channelとは?

 ニューラル・ネットワークにおけるChannelとは,画像と垂直な方向に構成する層のことです (説明を楽にするために画像処理を前提としてお話しています). よくある使い方は,カラー画像のRGBをChannel方向に充てるやり方です. 先日作ったGAN*2*3もそうです.

f:id:changlikesdesktop:20200719111223p:plain:w150

 U-Netで複数種類のネットワークを学習させる場合,このChannelに異常毎の推論結果を出します. 異常が6種類あるならば6 Channelになり,入力に対して常に6種類ぶんの推論結果を返します.

f:id:changlikesdesktop:20200719112635p:plain:w600

 Channel方向には畳み込みやPoolingなどの強力な処理は作用しません. しかし,繋がっています. 異なる異常を推論するはずのネットワークが繋がっている,つまり依存関係を持っているのです. 学習の過程でChannel間の依存関係が切れ,異常毎に独立した動きをするようになる必要があります. ことことから,Channelを使った複数以上の検出は学習に時間を要すると予想できます.

1. DLアプリを作る上で考慮すべき点

 ディープ・ラーニング(以後,DL)を構築して,アプリに実装することを考えてみます. 多く場合,DLの学習は大容量GPUボードを搭載したLenuxコンピュータで行うでしょう. 一方,ユーザーが使うアプリはiOSや,AndroidWindows上で動きます. ですので,学習結果を運用する場合には,学習済みのネットワークを運用側の環境に移植し,運用側の環境で推論処理を書く必要があります. 今回はC#を使ったWindowsアプリを想定します.

 快適なアプリを作る為に,以下の3点を考慮することにします.

(1) 学習時間

 学習時間はアプリの使い勝手とは関係がありません. しかし,学習に天文学的な時間を要するようではそもそもアプリをリリース出来ません. 現実的な時間で学習が終わるネットワーク構成にする必要があります.

(2) 学習済みネットワークの容量

 学習済みのネットワークは,使用しているライブラリ(僕の場合はtensorflowとkeras)が吐き出すバイナリファイルになります. アプリ側でこれをロードするわけですが,容量が大きすぎるとメモリやGPUを消費し過ぎてしまいます. アプリ側でGPUを使うのは現実的ではないと思っていますので,今回はメモリ消費量で議論します.

(3) 推論に要する時間

 アプリが答えを出すのに時間がかかるとイライラします. 出来るだけ短い時間で推論が終わるようにします.

(4) 異常検知の正確さ

 一番大切な事ですが,異常検知の精度を重視します. 

2. データセット

 DAGMデータセット*4の異常画像のみを使います.

f:id:changlikesdesktop:20200719095620p:plain:w400
DAGMデータセット
左上から順にClass_1,Class_2,Class_3,Class_4,Class_5,Class_6の異常画像.
各クラスについて,150枚の異常画像が用意されている.

3. プログラム

 これまでに書いてきたプログラムの応用なので,特に変わったことはしていません. 唯一,今回はC#を使ってビューワーを作ってみました.

f:id:changlikesdesktop:20200719085500p:plain:w400
推論結果のビューワー
PictureBoxに画像をドラッグ&ドロップ後にボタンを押すと,推論結果が表示される.
ListBoxには計算時間を表示している.

4. 結果

(1) 学習時間

 バッチサイズ8,エポック200の同条件で学習を行った計算時間を比較します.

Class 計算時間[s]
1 794
2 798
3 800
4 801
5 789
6 797
all 5488

 1 Channelで学習させる場合,1クラスあたりおよそ800 秒(13.3分)かかりました. 対して,6 Channelで全てのクラスを同時に学習させた場合は約5500 秒(91.6分)でした. 1 Channelの学習時間の6倍よりも6 Channelの学習時間の方が長いですが,その差はわずか(12分)でした.

 学習の速さも,変わりませんでした. 1 Channelの方が直ぐに正解率が向上すると予想していたのですが,テストデータに対する正解率は1 Channelでも6 Channelでも大凡50 epochで飽和しました.

f:id:changlikesdesktop:20200719115348p:plain:w400
テストデータに対する正解率の推移

(2) 学習済みネットワークの容量

 1 Channelの場合も6 Channelの場合も殆ど変わりませんでした. 若干(7KB),6Channel版の方が大きいですが,誤差です.

Class 学習済みネットワークの容量[KB]
1 404.469
2 404.469
3 404.469
4 404.469
5 404.469
6 404.469
all 404.506

 実際にモデルをロードした時のメモリ消費量は正確には測れませんでしたが,タスクマネージャーで見る限りでは1 Channelでも6 Channelでも変わらず,600 MB位でした. Channelが増えるとネットワークが大きくなるような気がするのですが,以外です. 1 Channelのネットワークを6個ロードすると,3.6 GByteのメモリを消費します. 6 Channelの方が圧倒的に有利だと言えます.

(3) 推論に要する時間

 統計的に調べてはいませんが,僕のPC環境では一律6秒位でした. 6 Channelの方が少しだけ時間が掛かりましたが,誤差レベル(数10 ms)でした.

(4) 異常検知の正確さ

 Class_1の異常画像を,1 Channelで学習した各Classのネットワークで推論してみると,Class_5とClass_6でも異常と判定されました.

f:id:changlikesdesktop:20200719093912p:plain:w600

 一方,6 Channelの学習済みネットワークでは,ちゃんとClass_1のみが異常と判定されました.

f:id:changlikesdesktop:20200719093932p:plain:w400

 全Classを総合すると,6 Channelの方が良いという事ですね.

5. 考察

Channel数 学習時間 容量 推論時間 推論精度
1
6

 予想と大きく反する結果になりました... 少なくとも推論の精度は1 Channelの方が優れていると思ったのですが,他のクラスで誤検出し易いという落とし穴がありました. 考えてみれば当然ですね. 1 Channelで,他Classの画像を異常無しとして学習させる方法も考えられますが,異常を強調するための重み付け*5など,煩雑な処理が必要になると予想されます.

 1 Channelが優れている点を挙げるとすれば,拡張性でしょうか. 解析する異常の種類を後から追加する場合,元々の異常の種類と同じChannel数でネットワークを組んでしまっていると始めからやり直しになります. Channel数に余裕を持っていれば,追加する異常について追加学習をすることも出来ると思いますが,,,調査が必要です. 現状の知見で判断すると,システムの仕様が固まって検知する異常の種類が固まっていれば複数チャンネル,その前段階の下調べでは単一Channelを使うのが良いと思います.

6. むすび

 新たな発見があって楽しめました. 1 Channelと6 Channelで推論時間が変わらないと言うのはどういう理屈なのか,,,自分の知識不足で理解できません(汗). 引き続き探求していこうと思います.

 今回書いたソースはここです*6