オッサンはDesktopが好き

自作PCや機械学習、自転車のことを脈絡無く書きます

U-Netで6種類の傷を検出してみる

前回のトライ*1を拡張して,DAGMの全6カテゴリーの傷をU-Netで学習させてみます.
2カテゴリーで出来たのだから,6カテゴリーでも簡単にできるだろうと思ったのですが,案外苦労しました(汗)

入力画像

DAGM画像を↓↓の様に使いました.

Category 学習 テスト
Class1 1-120 121-150
Class2 1-120 121-150
Class3 1-120 121-150
Class4 1-120 121-150
Class5 1-120 121-150
Class6 1-120 121-150

ソース

学習データ作成: configurate_data.py

データが大きすぎてnumpy.savetxtが例外になってしまった(メモリ不足)為,ファイル出力関数を自作しています.
一行づつファイルオープン & クローズすることで,メモリー消費を減らしています.
当然,処理時間はかかります.

def savetxt(filename, data):
    for i in range(data.shape[0]):
        print(filename + ', line ' + str(i))
        file = open(filename, 'a')     
        file.write('{:.9f}'.format(data[i, 0]))
        for j in range(1, data.shape[1]):
            file.write(',' + '{:.9f}'.format(data[i, j]))
        file.write('\n')
        file.close()

学習: train.py

テスト画像が多くて(30枚 x 6カテゴリ)評価関数の計算でResource exhaustedになりました.
学習画像と同様に,特定数をランダムに抽出して計算するようにしました.

if __name__=='__main__':  

     ...

     # display logs per step
     if epoch % display_step == 0:
          component = defineTestComtents() 
          minitest_x, minitest_y = next_test_set(component)
          cost_val = sess.run(cost, feed_dict={x: minitest_x, y: minitest_y, input_size: batch_size})

結果

f:id:changlikesdesktop:20190826060507p:plain:w500
ネットワーク入出力
上段: 左から,入力画像,Class1のラベル,Class2のラベル,,,Class6のラベル
下段: 左から,Class1の検出結果,Class2の検出結果,,,Class6の検出結果

f:id:changlikesdesktop:20190826060522p:plain:w500
f:id:changlikesdesktop:20190826060537p:plain:w500
f:id:changlikesdesktop:20190826060549p:plain:w500
f:id:changlikesdesktop:20190826060618p:plain:w500
f:id:changlikesdesktop:20190826060628p:plain:w500
f:id:changlikesdesktop:20190826060645p:plain:w500
検出結果の拡大
上段から順に,Class1,Class2,,,Class6

  • カテゴリ4が,背景と傷が似ていて少し難しそう
  • (2種類の場合も同じだが)対照傷では無い傷をご検出する(例えば,カテゴリ2用のネットワークがカテゴリ1用のネットワークに反応してしまう)頻度はかなり高い

疑わしい部分を警告する意味では良いのですが,実運用には工夫が要ると思いました.
また,傷の種類を2から6に増やすのが結構大変でした.
リソースの工面にこれほど苦労するくらいなら,傷の種類ごとに完全に別ネットワークで学習させた方が楽だと思います.

前回のソースを更新する仕方で公開しています*2
次は,傷無しの画像をネットワークに入れてみて,どの程度ご検出が起こるのか見てみようと思います.