前回のトライ*1を拡張して,DAGMの全6カテゴリーの傷をU-Netで学習させてみます.
2カテゴリーで出来たのだから,6カテゴリーでも簡単にできるだろうと思ったのですが,案外苦労しました(汗)
入力画像
DAGM画像を↓↓の様に使いました.
| Category | 学習 | テスト |
|---|---|---|
| Class1 | 1-120 | 121-150 |
| Class2 | 1-120 | 121-150 |
| Class3 | 1-120 | 121-150 |
| Class4 | 1-120 | 121-150 |
| Class5 | 1-120 | 121-150 |
| Class6 | 1-120 | 121-150 |
ソース
学習データ作成: configurate_data.py
データが大きすぎてnumpy.savetxtが例外になってしまった(メモリ不足)為,ファイル出力関数を自作しています.
一行づつファイルオープン & クローズすることで,メモリー消費を減らしています.
当然,処理時間はかかります.
def savetxt(filename, data):
for i in range(data.shape[0]):
print(filename + ', line ' + str(i))
file = open(filename, 'a')
file.write('{:.9f}'.format(data[i, 0]))
for j in range(1, data.shape[1]):
file.write(',' + '{:.9f}'.format(data[i, j]))
file.write('\n')
file.close()
学習: train.py
テスト画像が多くて(30枚 x 6カテゴリ)評価関数の計算でResource exhaustedになりました.
学習画像と同様に,特定数をランダムに抽出して計算するようにしました.
if __name__=='__main__':
...
# display logs per step
if epoch % display_step == 0:
component = defineTestComtents()
minitest_x, minitest_y = next_test_set(component)
cost_val = sess.run(cost, feed_dict={x: minitest_x, y: minitest_y, input_size: batch_size})
結果

ネットワーク入出力
上段: 左から,入力画像,Class1のラベル,Class2のラベル,,,Class6のラベル
下段: 左から,Class1の検出結果,Class2の検出結果,,,Class6の検出結果






検出結果の拡大
上段から順に,Class1,Class2,,,Class6
- カテゴリ4が,背景と傷が似ていて少し難しそう
- (2種類の場合も同じだが)対照傷では無い傷をご検出する(例えば,カテゴリ2用のネットワークがカテゴリ1用のネットワークに反応してしまう)頻度はかなり高い
疑わしい部分を警告する意味では良いのですが,実運用には工夫が要ると思いました.
また,傷の種類を2から6に増やすのが結構大変でした.
リソースの工面にこれほど苦労するくらいなら,傷の種類ごとに完全に別ネットワークで学習させた方が楽だと思います.
前回のソースを更新する仕方で公開しています*2.
次は,傷無しの画像をネットワークに入れてみて,どの程度ご検出が起こるのか見てみようと思います.