TensorFlow 模型的保存和恢复代码
代码参考《TensorFlow:实战Google深度学习框架》,本地手打,调试后复制出来,和原文会有差别。
不同于普通的保存和读取,读取的时候还是需要定义一下数据。之前想着 TensorFlow 训练好的模型,不能每次都要重新跑吧,先看了一下 Saver 相关的内容。
TensorFlow 官方文档地址:https://www.tensorflow.org/api_docs/python/tf/train/Saver
save demo
import tensorflow as tf v1=tf.Variable(tf.constant(1.0,shape=[1]),name="v1") v2=tf.Variable(tf.constant(2.0,shape=[1]),name="v2") result=v1+v2 init_op=tf.global_variables_initializer() saver=tf.train.Saver() with tf.Session() as sess: sess.run(init_op) saver.save(sess,"./model/model.ckpt")
restore demo
import tensorflow as tf v1=tf.Variable(tf.constant(1.0,shape=[1]),name="v1") v2=tf.Variable(tf.constant(2.0,shape=[1]),name="v2") result=v1+v2 saver=tf.train.Saver() with tf.Session() as sess: saver.restore(sess,"./model/model.ckpt") print(sess.run(result))
原创文章,作者:fendouai,如若转载,请注明出处:https://panchuang.net/2017/07/07/tensorflow-save-restore/