申请专栏作者 参展
投稿发布
您的当前位置:主页 > 机器学习 > 正文

Tensorflow 的 checkpoint 教程

来源: 时间:2019-10-24
请支持本站,点击下面的广告后浏览!

checkpoint 主要的目的有两个: 可思数据sykv.com

1. 如果训练过程中出现的意外情况,可以通过 checkpoint 快速恢复
2. 通过 checkpoint 可以stop early,这样使得算法效果更好 可思数据-数据挖掘,智慧医疗,机器视觉,机器人sykv.com

keras

在 keras 中使用Model.save_weights方法来生成 checkpoint.
但是如果使用这个方法的话,Model 的 layer 必须分配给一个成员变量,特别是在构造器中?
并且如果是Model.save_weights方法生成的 checkpoint, 需要使用Model.load_weights来加载,不能使用tf.train.Checkpoint.restore来进行加载。 本文来自可思数据(sykv.com),转载请联系本站及注明出处

API 文档中建议使用tf.train.Checkpoint来生成 checkpoint 本文来自可思数据(sykv.com),转载请联系本站及注明出处

tf.train.Checkpoint

这个包是 tensorflow 中负责 checkpoint 的全生命周期的管理,包括:
1. 定义 checkpoint 生成策略
2. 管理 checkpoint 的恢复

本文来自可思数据(sykv.com),转载请联系本站及注明出处
# import
from __future__ import absolute_import,division,print_function,\
unicode_literals
import tensorflow as tf
 

可思数据-开元棋牌是个坑_代理开元棋牌的平台_你们怎么看开元棋牌资讯平台sykv.com

/Users/ki/anaconda3/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
  from ._conv import register_converters as _register_converters
 可思数据-AI,sykv.com开元棋牌是个坑_代理开元棋牌的平台_你们怎么看开元棋牌,深度学习,机器学习,神经网络 

定义 checkpoint 生成策略

在使用 checkpoint 之前,首先需要我们定义一个简单的网络与一个简单的输入,就像quickstart2中所介绍的构建方式一样

可思数据sykv.com,sykv.cn

class Net(tf.keras.Model):
    """just a simple linear model"""

    def __init__(self):
        super(Net,self).__init__()
        self.l1 = tf.keras.layers.Dense(5)
    def call(self,x):
        return self.l1(x)

def toy_dataset():
    inputs = tf.range(10.)[:,None]
    labels = inputs * 5. + tf.range(5.)[None,:]
    return tf.data.Dataset.from_tensor_slices(
        dict(x=inputs,y=labels)).repeat(10).batch(2)

def train_step(net,example,optimizer):
    """train net on example using optimizer"""
    with tf.GradientTape() as tape:
        output = net(example['x'])
        loss = tf.reduce_mean(tf.abs(output - example['y']))
    variables = net.trainable_variables
    gradients = tape.gradient(loss,variables)
    optimizer.apply_gradients(zip(gradients,variables))
    return loss
 内容来自可思数据sykv.com 

现在我希望在这个网络的训练过程中生成 checkpoint. 我需要怎么做?
首先需要明确的是,对于 tensorflow 来说他的主要的对象都是类似于tf.Variable,是一个拥有内部状态的一个对象,我们 checkout 的对象的状态,恢复的也是对象的状态,而不是恢复这个对象。
在以上的前提下存在 3 个问题:
1. 如何让对象被 checkpoint?
2. 什么对象才能被 checkpoint?
3. 如何从 checkpoint 恢复到对象中? 可思数据-数据挖掘,智慧医疗,机器视觉,机器人sykv.com


  • 本文地址:http://www.6aiq.com/article/1571815127141
  • 本文版权归作者和AIQ共有,欢迎转载,但未经作者同意必须保留此段声明,且在文章页面明显位置给出
  • 知乎专栏 点击关注

tf.train.Checkpoint是 tf2.0 新增的功能,在 tf1.X 中由 train.Saver 进行支持。不进行赘述。在 tf2.0 中,Checkpoint 是基于 python 对象进行序列化。 可思数据sykv.com,sykv.cn

tf.train.Checkpoint类的构造器:__init__(**kwag), 通过构造时传入你所希望 checkout 的对象,然后在后续 checkout 过程中就会将所传入的对象进行 checkout。下面这段常规的代码可以回答这三个问题 可思数据-www.sykv.cn,sykv.com

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
ckpt = tf.train.Checkpoint(
    step = tf.Variable(1), #迭代轮数
    optimizer = opt, #优化器
    net = net #网络
)#ckpt相当于定义了一个容器,将构造器中传入的对象的格式固定下来,并链接到这个对象
manager = tf.train.CheckpointManager(
    checkpoint = ckpt, #checkpoint
    directory ='./tf_ckpts_test', #存储路径
    max_to_keep = 3, #最大checkpoint份数
    keep_checkpoint_every_n_hours = None, #checkpoint 时间间隔
    checkpoint_name = 'test' #checkpoint 文件名
)
for example in toy_dataset():
    ckpt.step.assign_add(1)
    ckpt.restore(manager.latest_checkpoint) #将最后一个checkout恢复
    loss = train_step(net,example,opt)
    print("loss {:1.2f}".format(loss.numpy()))
    manager.save() #根据将当前checkout容器中的对象状态保存
 可思数据-开元棋牌是个坑_代理开元棋牌的平台_你们怎么看开元棋牌资讯平台sykv.com 
loss 2.08
loss 0.98
loss 2.09
loss 3.38
loss 4.45
loss 2.53
loss 1.36
loss 1.28
loss 2.03
loss 2.66
loss 2.78
loss 1.96
loss 1.24
loss 0.82
loss 0.78
loss 2.79
loss 2.07
loss 1.34
loss 0.85
loss 0.68
loss 2.61
loss 1.78
loss 1.08
loss 0.75
loss 1.13
loss 2.37
loss 1.46
loss 0.68
loss 0.55
loss 1.33
loss 2.20
loss 1.37
loss 0.59
loss 0.49
loss 1.13
loss 2.07
loss 1.30
loss 0.54
loss 0.42
loss 1.05
loss 1.90
loss 1.20
loss 0.59
loss 0.32
loss 0.83
loss 1.76
loss 1.16
loss 0.60
loss 0.26
loss 0.71
 
可思数据sykv.com

其中 keep_checkpoint_every_n_hours 参数保存的 checkpoint 并不会收到 max_to_keep 参数的限制,max_to_keep 限制的是通过 save 函数主动 checkpoint 的数据

可思数据sykv.com

Q1 如何让对象被 checkpoint?

通过将对象通过Checkpoint构造器传入,也可以通过listed或者mapped来传入 list 或者 dictionary 对象来灵活的构造 Checkpoint 可思数据-开元棋牌是个坑_代理开元棋牌的平台_你们怎么看开元棋牌资讯平台sykv.com

Q2 什么对象才能 checkpoint?

TrackableBase 派生出来的对象 才能被 checkout 本文来自可思数据(sykv.com),转载请联系本站及注明出处

testList =  []
ckpt_error = tf.train.Checkpoint(
    test = testList,
    step = tf.Variable(1), #记录迭代轮数
    optimizer = opt, #记录优化器状态
    net = net #记录网络状态
)
 
可思数据sykv.com,sykv.cn
---------------------------------------------------------------------------

ValueError                                Traceback (most recent call last)

 in ()
      4     step = tf.Variable(1), #记录迭代轮数
      5     optimizer = opt, #记录优化器状态
----> 6     net = net #记录网络状态
      7 )

~/anaconda3/lib/python3.6/site-packages/tensorflow_core/python/training/tracking/util.py in __init__(self, **kwargs)
   1777              "object should be trackable (i.e. it is part of the "
   1778              "TensorFlow Python API and manages state), please open an issue.")
-> 1779             % (v,))
   1780       setattr(self, k, v)
   1781     self._save_counter = None  # Created lazily for restore-on-create.

ValueError: `Checkpoint` was expecting a trackable object (an object derived from `TrackableBase`), got []. If you believe this object should be trackable (i.e. it is part of the TensorFlow Python API and manages state), please open an issue.
 内容来自可思数据sykv.com 

Q3 如何从 checkpoint 恢复到对象中?

定义相同参数的对象,然后将这些对象构造一个 Checkout,然后调用 restore 方法,从指定的路径上恢复。 可思数据sykv.com,sykv.cn

#重复restore会产生警告
opt3 = tf.keras.optimizers.Adam(0.1)
net3 = Net()
step3 = tf.Variable(1)
print("restore前:{}".format(step3.numpy()))
ckpt3 = tf.train.Checkpoint(
    optimizer=opt3,
    step=step3
    ,net=net3)
file = tf.train.latest_checkpoint('./tf_ckpts_test')
print(file)
status = ckpt3.restore(file)
print("restore后:{}".format(step3.numpy()))
 内容来自可思数据sykv.com 
restore前:1
./tf_ckpts_test/test-100
restore后:51
 可思数据-AI,sykv.com开元棋牌是个坑_代理开元棋牌的平台_你们怎么看开元棋牌,深度学习,机器学习,神经网络 

Q3.1 为什么这里好像显示 net3 没有被加载成功呢?调用 trainable_variables 显示的结果不同

net3.trainable_variables
 
可思数据-数据挖掘,智慧医疗,机器视觉,机器人sykv.com
[]
 本文来自可思数据(sykv.com),转载请联系本站及注明出处 
net.trainable_variables
 
可思数据-www.sykv.cn,sykv.com
[,
 ]
 可思数据sykv.com 

原因在于 restore 是延迟加载(Delayed restorations)。Layer 对象会将其内部的 Variable 的创建延迟到其首次调用。 可思数据-AI,sykv.com开元棋牌是个坑_代理开元棋牌的平台_你们怎么看开元棋牌,深度学习,机器学习,神经网络

Estimator 对 Checkpoint 的额外支持

Estimator 有一个默认的 CheckoutManager,只要你在 model_fn 内部构造了 Checkpoint 对象。那么在训练中就会保存下来每一轮的模型,但只保留最新的一份 checkpoint。

可思数据sykv.com,sykv.cn

import tensorflow.compat.v1 as tf_compat
def model_fn(features, labels, mode):
    net = Net() #定义模型
    opt = tf.keras.optimizers.Adam(0.1) #定义迭代器
    ckpt = tf.train.Checkpoint(step=tf_compat.train.get_global_step(),
                             optimizer=opt, net=net) #定义Checkout
    #不需要定义CheckoutManager,因为Estimator自带默认的CheckoutManager
    #开始训练
    with tf.GradientTape() as tape:
        output = net(features['x'])
        loss = tf.reduce_mean(tf.abs(output - features['y']))
        variables = net.trainable_variables
        gradients = tape.gradient(loss, variables)
    return tf.estimator.EstimatorSpec(
    mode,
    loss=loss,
    train_op=tf.group(opt.apply_gradients(zip(gradients, variables)),
                      ckpt.step.assign_add(1)),
    # Tell the Estimator to save "ckpt" in an object-based format.
    scaffold=tf_compat.train.Scaffold(saver=ckpt))

tf.keras.backend.clear_session()
est = tf.estimator.Estimator(model_fn,# 模型
                             './tf_estimator_example/'#checkout路径与模型路径
                            )
est.train(toy_dataset, steps=10)
 可思数据sykv.com,sykv.cn 
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': , '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from ./tf_estimator_example/model.ckpt-10
WARNING:tensorflow:From /Users/ki/anaconda3/lib/python3.6/site-packages/tensorflow_core/python/training/saver.py:1069: get_checkpoint_mtimes (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file utilities to get mtimes.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:loss = 3.5265698, step = 11
WARNING:tensorflow:It seems that global step (tf.train.get_global_step) has not been increased. Current value (could be stable): 12 vs previous value: 12. You could increase the global step by passing tf.train.get_global_step() to Optimizer.apply_gradients or Optimizer.minimize.
WARNING:tensorflow:It seems that global step (tf.train.get_global_step) has not been increased. Current value (could be stable): 15 vs previous value: 15. You could increase the global step by passing tf.train.get_global_step() to Optimizer.apply_gradients or Optimizer.minimize.
WARNING:tensorflow:It seems that global step (tf.train.get_global_step) has not been increased. Current value (could be stable): 18 vs previous value: 18. You could increase the global step by passing tf.train.get_global_step() to Optimizer.apply_gradients or Optimizer.minimize.
INFO:tensorflow:Saving checkpoints for 20 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Loss for final step: 33.14527.
 内容来自可思数据sykv.com 
# 恢复checkpoint
opt = tf.keras.optimizers.Adam(0.1)
net = Net()
ckpt = tf.train.Checkpoint(
  step=tf.Variable(1, dtype=tf.int64), optimizer=opt, net=net)
ckpt.restore(tf.train.latest_checkpoint('./tf_estimator_example/'))
ckpt.step.numpy() #20 

可思数据-www.sykv.cn,sykv.com


转发量:

网友评论:

发表评论
请自觉遵守互联网相关的政策法规,严禁发布色情、暴力、反动的言论。
评价:
表情:
用户名: 验证码:点击我更换图片 匿名?
数据标注服务

关于我们   免责声明   广告合作   版权声明   联系方式   原创投稿   网站地图  

Copyright©2005-2019 Sykv.com 可思数据 版权所有    ICP备案:京ICP备14056871号

开元棋牌是个坑_代理开元棋牌的平台_你们怎么看开元棋牌资讯   开元棋牌是个坑_代理开元棋牌的平台_你们怎么看开元棋牌资讯   开元棋牌是个坑_代理开元棋牌的平台_你们怎么看开元棋牌资讯   开元棋牌是个坑_代理开元棋牌的平台_你们怎么看开元棋牌资讯

扫码入群
咨询反馈
扫码关注

微信公众号

返回顶部
关闭