代码参考《TensorFlow:实战Google深度学习框架》,本地手打,调试后复制出来,和原文会有差别。
不同于普通的保存和读取,读取的时候还是需要定义一下数据。之前想着 TensorFlow 训练好的模型,不能每次都要重新跑吧,先看了一下 Saver 相关的内容。
TensorFlow 官方文档地址:https://www.tensorflow.org/api_docs/python/tf/train/Saver
save demo
import tensorflow as tf
v1=tf.Variable(tf.constant(1.0,shape=[1]),name="v1")
v2=tf.Variable(tf.constant(2.0,shape=[1]),name="v2")
result=v1+v2
init_op=tf.global_variables_initializer()
saver=tf.train.Saver()
with tf.Session() as sess:
sess.run(init_op)
saver.save(sess,"./model/model.ckpt")
restore demo
import tensorflow as tf
v1=tf.Variable(tf.constant(1.0,shape=[1]),name="v1")
v2=tf.Variable(tf.constant(2.0,shape=[1]),name="v2")
result=v1+v2
saver=tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess,"./model/model.ckpt")
print(sess.run(result))