前回のトライ*1を拡張して,DAGMの全6カテゴリーの傷をU-Netで学習させてみます.
2カテゴリーで出来たのだから,6カテゴリーでも簡単にできるだろうと思ったのですが,案外苦労しました(汗)
入力画像
DAGM画像を↓↓の様に使いました.
Category | 学習 | テスト |
---|---|---|
Class1 | 1-120 | 121-150 |
Class2 | 1-120 | 121-150 |
Class3 | 1-120 | 121-150 |
Class4 | 1-120 | 121-150 |
Class5 | 1-120 | 121-150 |
Class6 | 1-120 | 121-150 |
ソース
学習データ作成: configurate_data.py
データが大きすぎてnumpy.savetxtが例外になってしまった(メモリ不足)為,ファイル出力関数を自作しています.
一行づつファイルオープン & クローズすることで,メモリー消費を減らしています.
当然,処理時間はかかります.
def savetxt(filename, data): for i in range(data.shape[0]): print(filename + ', line ' + str(i)) file = open(filename, 'a') file.write('{:.9f}'.format(data[i, 0])) for j in range(1, data.shape[1]): file.write(',' + '{:.9f}'.format(data[i, j])) file.write('\n') file.close()
学習: train.py
テスト画像が多くて(30枚 x 6カテゴリ)評価関数の計算でResource exhaustedになりました.
学習画像と同様に,特定数をランダムに抽出して計算するようにしました.
if __name__=='__main__': ... # display logs per step if epoch % display_step == 0: component = defineTestComtents() minitest_x, minitest_y = next_test_set(component) cost_val = sess.run(cost, feed_dict={x: minitest_x, y: minitest_y, input_size: batch_size})
結果
ネットワーク入出力
上段: 左から,入力画像,Class1のラベル,Class2のラベル,,,Class6のラベル
下段: 左から,Class1の検出結果,Class2の検出結果,,,Class6の検出結果
検出結果の拡大
上段から順に,Class1,Class2,,,Class6
- カテゴリ4が,背景と傷が似ていて少し難しそう
- (2種類の場合も同じだが)対照傷では無い傷をご検出する(例えば,カテゴリ2用のネットワークがカテゴリ1用のネットワークに反応してしまう)頻度はかなり高い
疑わしい部分を警告する意味では良いのですが,実運用には工夫が要ると思いました.
また,傷の種類を2から6に増やすのが結構大変でした.
リソースの工面にこれほど苦労するくらいなら,傷の種類ごとに完全に別ネットワークで学習させた方が楽だと思います.
前回のソースを更新する仕方で公開しています*2.
次は,傷無しの画像をネットワークに入れてみて,どの程度ご検出が起こるのか見てみようと思います.