TensorFlow 模型的保存和恢复代码

代码参考《TensorFlow:实战Google深度学习框架》,本地手打,调试后复制出来,和原文会有差别。

不同于普通的保存和读取,读取的时候还是需要定义一下数据。之前想着 TensorFlow 训练好的模型,不能每次都要重新跑吧,先看了一下  Saver 相关的内容。

TensorFlow 官方文档地址:https://www.tensorflow.org/api_docs/python/tf/train/Saver

save demo

import  tensorflow as tf
v1=tf.Variable(tf.constant(1.0,shape=[1]),name="v1")
v2=tf.Variable(tf.constant(2.0,shape=[1]),name="v2")

result=v1+v2

init_op=tf.global_variables_initializer()
saver=tf.train.Saver()

with tf.Session() as sess:
    sess.run(init_op)
    saver.save(sess,"./model/model.ckpt")

restore demo

import tensorflow as tf

v1=tf.Variable(tf.constant(1.0,shape=[1]),name="v1")
v2=tf.Variable(tf.constant(2.0,shape=[1]),name="v2")

result=v1+v2
saver=tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess,"./model/model.ckpt")
    print(sess.run(result))

Related posts

Leave a Comment