前回*1の続きです.
今回は学習済みネットワークをpbファイル形式で移植します.
1. チェックポイントをpbファイルに変換
チェックポイントを読み込み,pbファイルに変換して保存します.
freeze_model.py
import tensorflow as tf
def main():
with tf.Session() as sess:
# Restore the graph
saver = tf.train.import_meta_graph("./model.meta", clear_devices=True)
# Load weights
model_path = tf.train.latest_checkpoint("./")
saver.restore(sess, model_path)
# Freeze the graph
frozen_graph_def = tf.graph_util.convert_variables_to_constants(
sess,
sess.graph_def,
['y', 'accuracy']
)
tf.train.write_graph(frozen_graph_def, './', 'frozen_graph.pb', as_text=False)
tf.train.write_graph(frozen_graph_def, './', 'frozen_graph.txt', as_text=True)
if __name__ == '__main__':
main()
ポイントは,学習時に指定した出力層の名前('y', 'accuracy')を,nodeに指定するところです.
出力される↓を,Windowsに移植します.
- frozen_graph.pb
Note: frozen_graph.txtは推論には使いませんが,node名を把握するのに便利なので出力しています.
2. pbファイルを読み込んで推論
pbファイルを読み込みます.
apply_model_pb.py
with open('frozen_graph.pb', 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
落とし穴になったのが,チェックポイントの読み込みでは推論に必要な入力画像のみをネットワークに入れれば良かったのに対し,pbファイルの場合には形式的にラベルとドロップアウト値も入力しなければならない点でした.
Placeholderを再指定していないので,当然なのですが...
decoded_imgs = sess.run('y:0', feed_dict={'x:0': dataImages[0:30,:], 'y_:0': dataLabels[0:30, :], 'keep_prob:0': 1.0})
decoded_imgs = decoded_imgs.reshape([-1, OUTPUT_SIZE])
学習時に指定したPlaceholderの名前('y', 'x', 'y'_, 'keep_prob')を正しく指定することも重要です.
上記で出力したfrozen_graph.txtと,↓↓↓部分の出力で確認します.
print('=' * 60)
for op in tf.get_default_graph().get_operations():
print(op.name)
for output in op.outputs:
print(' ', output.name)
print('=' * 60)
チェックポイントからの推論と同じ結果になりました.

今回書いたソースはここ*2にあります.