こんにちは.changです. kerasでマルチクラスのU-Netを書くとちょっと変だな,と以前から気になっていました. 今回,それが解決したので記録します. 結果だけを言うと,ver.2になって不具合(?)が治ったという話です.
1. Tensorflow ver.1 + kerasによるU-Net
個人的にU-Netを良く使うのですが,Tensorflow単独での記述からkerasに移行するときに少し苦労しました*1. また,マルチクラス化したとき*2,多少無理矢理にビルドを通した経緯がありました.
問題になっているのは下記の箇所です.githubにソースが上がっています*3.
model.py
class MyModel: def __init__(self, input_size, batch_size, epochs): self.input_size = input_size ...(省略)... def create_model(self): inputs = Input(self.input_size) ...(省略)... self.model = Model(input=inputs, output=outputs) # ここがエラーになる
train.py
if __name__=='__main__': ...(省略)... model = model.MyModel((c.IMG_SIZE, c.IMG_SIZE, channel), batch_size, training_epochs)
クラスのインスタンスを作る際に,input_sizeには(c.IMG_SIZE, c.IMG_SIZE, channel)を指定していました. こうしないと,self.model = Model(input=inputs, output=outputs)がエラーになった為です.
結果,ニューラルネットワークに画像を入力する際にクラス数ぶんだけ画像を複製する帳尻合わせが生じます. 上記の例の場合にはクラス数が6だったので,6枚の画像をニューラル・ネットワークに入力する事になります. 動きはするのですが,メモリを無駄に消費しますし,何より気持ちが悪いです. "RGBの3チャンネルを入力 & 出力クラスは3限定"なんて言うソースも散見します.
ちなみに,Tensorflow単独ではこんな現象は無く,画像1枚をニューラル・ネットワークに入力出来ていました*4.
train.py
if __name__=='__main__': with tf.device("/gpu:0"): with tf.Graph().as_default(): with tf.variable_scope("scope_model"): x = tf.placeholder(tf.float32, [None, IMG_SIZE*IMG_SIZE]) # inputs(gray image) y = tf.placeholder(tf.float32, [None, IMG_SIZE*IMG_SIZE*CATEGORY]) # teacher
2. Tensorflow ver.2によるU-Net
ひょんなことから,Tensorflow ver.2では上記の不具合(?)が解消されている事に気づきました.
model.py
class MyModel: def __init__(self, input_size, num_classes, batch_size, epochs): self.input_size = input_size ...(省略)... def create_model(self): inputs = Input(self.input_size) ...(省略)... self.model = Model(inputs=inputs, outputs=outputs)
train.py
if __name__=='__main__': ... model = model.MyModel((c.IMG_SIZE, c.IMG_SIZE, 1), channel, batch_size, training_epochs)
これに合わせて,ソース全体を書き直しました*5.
3. むすび
厳密には,どのver.から不具合が治ったのか判っていません. 明らかな変化があったのは下記のver.になります.
不具合あり: tensorflow-gpu ver. 1.14.0 & keras ver. 2.3.1
不具合なし: tensorflow-gpu ver. 2.3.0
何だったんだ?って感じですが,,,モヤモヤが少し晴れました. ちなみに,U-Netの様な有名なモデルはライブラリ化されていることが多いです. 例えば*6. 多くの方は,ライブラリを上手に活用されているのでしょう. モデル構造を自分で書くという非効率な事をやっている為に招いた事態と言えますね(..)
4. おまけ
記事の内容とは全く関係ないんですが,久しぶりにソースを書いたらgithubの使い方が変わっていました. 認証キーを発行しないとpush出来なくなったんですね. キーの更新の仕方等を理解していないので,慣れていかないとです.
*1:https://changlikesdesktop.hatenablog.com/entry/2020/05/25/090818
*2:https://changlikesdesktop.hatenablog.com/entry/2020/07/19/132644
*3:https://github.com/changGitHubJ/U-Net_channels
*4:https://github.com/changGitHubJ/U-net_DAGM_categories/blob/master/train.py
*5:https://github.com/changGitHubJ/U-Net_tfv2
*6:https://segmentation-models.readthedocs.io/en/latest/tutorial.html