オッサンはDesktopが好き

自作PCや機械学習、自転車のことを脈絡無く書きます

U-Netをkerasで書くと一寸大変だった,という話

 kerasが使い易いので本格的に乗り換えようと思い,画像処理のメインツールとして使ってきたU-Netを移植しようとしたら,案外苦労したという話です. 以前に,tensorflow単独で書いたDAGMの異常検知*1を,keras + tensorflowで書き直します.

0. Sequentialが使えない

 kerasの(新しめの?)サンプルを見るとSequentialでネットワークを書いている事が多いと思います. 実際,mnist*2もSequentialで構築しました. ところが,連結(concatenate)などの複雑なネットワークを組もうとすると,Sequentialが使えません(データの流れがSequentialでは無いため). このため,inputとoutputを指定してオーソドックスにモデルを書く必要がありました.

model.py

inputs = Input(self.input_size)

# encoding ##############
conv1_1 = Conv2D(64, 3, activation="relu", padding="same", kernel_initializer="he_normal", bias_initializer="zeros")(inputs)
conv1_2 = Conv2D(64, 3, activation="relu", padding="same", kernel_initializer="he_normal", bias_initializer="zeros")(conv1_1)
pool1 = MaxPooling2D(pool_size=(2, 2))(conv1_2)

conv2_1 = Conv2D(128, 3, activation="relu", padding="same", kernel_initializer="he_normal", bias_initializer="zeros")(pool1)
conv2_2 = Conv2D(128, 3, activation="relu", padding="same", kernel_initializer="he_normal", bias_initializer="zeros")(conv2_1)
pool2 = MaxPooling2D(pool_size=(2, 2))(conv2_2)

## 省略 ##

concated6 = concatenate([conv3_2, conv_up6], axis=3)

conv7_1 = Conv2D(256, 3, activation="relu", padding="same", kernel_initializer="he_normal", bias_initializer="zeros")(concated6)
conv7_2 = Conv2D(256, 3, activation="relu", padding="same", kernel_initializer="he_normal", bias_initializer="zeros")(conv7_1)
conv_up7 = Conv2D(128, 3, activation="relu", padding="same", kernel_initializer="he_normal", bias_initializer="zeros")(UpSampling2D(size=(2, 2))(conv7_2))
concated7 = concatenate([conv2_2, conv_up7], axis=3)

conv8_1 = Conv2D(128, 3, activation="relu", padding="same", kernel_initializer="he_normal", bias_initializer="zeros")(concated7)
conv8_2 = Conv2D(128, 3, activation="relu", padding="same", kernel_initializer="he_normal", bias_initializer="zeros")(conv8_1)
conv_up8 = Conv2D(64, 3, activation="relu", padding="same", kernel_initializer="he_normal", bias_initializer="zeros")(UpSampling2D(size=(2, 2))(conv8_2))
concated8 = concatenate([conv1_2, conv_up8], axis=3)
conv9_1 = Conv2D(64, 3, activation="relu", padding="same", kernel_initializer="he_normal", bias_initializer="zeros")(concated8)
conv9_2 = Conv2D(64, 3, activation="relu", padding="same", kernel_initializer="he_normal", bias_initializer="zeros")(conv9_1)
outputs = Conv2D(1, 1, activation="sigmoid")(conv9_2)

self.model = Model(input=inputs, output=outputs)

1. 適切な損失関数を選ぶ,或いは関数に合わせたデータ形式にする必要がある

 tensorflow単独に書いたソースでは,ラベルがfloat型になっていました. 傷の大きさの違いの影響を避けるために,傷の面積でラベルを正規化していたためです.

 このラベルが,keras+tensorflowのbinary_cross_entropyに適合しないというのが,今回,一番手こずったところです. 計算自体は回ってしまうので,何が悪いのか理解するのに時間がかかりました. keras+tensorflowで使用できる損失関数については公式サイト*3にまとめられていますが,それぞれの関数の詳細までは記載がありません. "binaryなんだから1 or 0”というのは,気づいてしまえば当たり前なんですが...(泣)

 修正方法としては,ラベルをnumpy.boolに型指定して構築しました.

load_data.py

def read_labels(self, filename, DATA_SIZE):
    LblOrg = np.loadtxt(filename, delimiter=',')
    Labels = np.zeros((DATA_SIZE, self.IMG_SIZE, self.IMG_SIZE, 1), dtype=np.bool)
    for i in range(DATA_SIZE):
        Labels[i, :, :, 0] = LblOrg[i, 1:self.IMG_SIZE*self.IMG_SIZE + 1].reshape(self.IMG_SIZE, self.IMG_SIZE)
    Labels = Labels.astype('bool')
    return Labels

 tensorflow単独と,keras+tensoflowで完全に同じ動きをさせたかったのですが,今回は断念です(>_<)

2. tensorflow単独と学習結果が異なる

f:id:changlikesdesktop:20200520192037p:plain:w600
左から,元画像,ラベル画像,tensorflow単独での推論結果,keras + tensorflowでの推論結果

 ネットワーク構造が同じであるにもかかわらず,tensorflow単独とkeras+tensorflowとで大きく異なる学習結果になりました. tensorflow単独では,ラベル領域の中に含まれる異常部分のみが抽出されています. 一方,keras + tensorflowは,画像の異常というよりも,ラベルの付け方の癖を学習した様に見えます. 上述した様にラベルの定義など,プログラム上の違いはあるのですが,ここまでの差異が生じる理由が解りません.

3. keras+tensorflowがやたらメモリを食う

 tensorflow単独のソースでは,自作のバッチ処理で20枚の画像をネットワークに入力して学習させていました. keras + tensorflowでbatch_size=20を指定するとメモリ不足でエラーになってしまいました. GPU領域を乱暴に取りに行くのかな...? メモリのとり方をネットワーク構造やデータ量に合わせて設計するべきですが,今回は試行錯誤的にbatch_size=8としました.`

4. まとめ

 ライブラリ依存の開発をすると,こういう事が起きますよね. ライブラリの内部処理を探求する(し過ぎる)ことは無意味だと思っているので,癖を掴んで取捨選択しながら使うことになります. kerasを使うことでコーディング作業が楽になることは間違いありません. その反面,ブラックボックスが大きくなるという事ですね. kerasのインターフェース部分だけをうまいこと使っていく様なテクニックが必要になるのかも知れないです.

 今回書いたソースはここ*4です.