オッサンはDesktopが好き

自作PCや機械学習、自転車のことを脈絡無く書きます

Lenux + tensorflowで作ったDeep Learningの学習済みネットワークを,Windowsに移植して推論する (2)

前回*1の続きです.
今回は学習済みネットワークをpbファイル形式で移植します.

1. チェックポイントをpbファイルに変換

チェックポイントを読み込み,pbファイルに変換して保存します.

freeze_model.py

import tensorflow as tf

def main():
    with tf.Session() as sess:
        # Restore the graph
        saver = tf.train.import_meta_graph("./model.meta", clear_devices=True)

        # Load weights
        model_path = tf.train.latest_checkpoint("./")
        saver.restore(sess, model_path)

        # Freeze the graph
        frozen_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,
            sess.graph_def,
            ['y', 'accuracy']
            )

        tf.train.write_graph(frozen_graph_def, './', 'frozen_graph.pb',  as_text=False)
        tf.train.write_graph(frozen_graph_def, './', 'frozen_graph.txt', as_text=True)

if __name__ == '__main__':
    main()

ポイントは,学習時に指定した出力層の名前('y', 'accuracy')を,nodeに指定するところです.

出力される↓を,Windowsに移植します.

  • frozen_graph.pb

Note: frozen_graph.txtは推論には使いませんが,node名を把握するのに便利なので出力しています.

2. pbファイルを読み込んで推論

pbファイルを読み込みます.

apply_model_pb.py

with open('frozen_graph.pb', 'rb') as f:
   graph_def = tf.GraphDef()
   graph_def.ParseFromString(f.read())
   tf.import_graph_def(graph_def, name='') 

落とし穴になったのが,チェックポイントの読み込みでは推論に必要な入力画像のみをネットワークに入れれば良かったのに対し,pbファイルの場合には形式的にラベルとドロップアウト値も入力しなければならない点でした.
Placeholderを再指定していないので,当然なのですが...

decoded_imgs = sess.run('y:0', feed_dict={'x:0': dataImages[0:30,:], 'y_:0': dataLabels[0:30, :], 'keep_prob:0': 1.0})
decoded_imgs = decoded_imgs.reshape([-1, OUTPUT_SIZE])

学習時に指定したPlaceholderの名前('y', 'x', 'y'_, 'keep_prob')を正しく指定することも重要です.
上記で出力したfrozen_graph.txtと,↓↓↓部分の出力で確認します.

print('=' * 60)
for op in tf.get_default_graph().get_operations():
   print(op.name)
   for output in op.outputs:
      print('  ', output.name)
   print('=' * 60)

チェックポイントからの推論と同じ結果になりました.

f:id:changlikesdesktop:20200402192657p:plain:w400

今回書いたソースはここ*2にあります.