Tensorflow在不同训练场景下读取和使用不同格式pretrained model的方法

本文介绍了在TensorFlow中加载预训练模型的多种方法,包括从断点继续训练、加载特定层权重、从不同格式(如checkpoint、frozen graph、.npy文件)读取模型等,适用于迁移学习等多种场景。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

不同应用场景分析与示例

Tensorflow读取预训练模型是模型训练中常见的操作,通常的应用的场景包括:

1)训练中断后需要重新开始,将保存之前的checkpoint(包括.data .meta .index checkpoint这四个文件),然后重新加载模型,从上次断点处继续训练或预测。实现方法如下:

  • 如果代码中已经构建好了网络结构图
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    ckpt = tf.train.get_checkpoint_state(os.path.dirname('checkpoints_path'))
    # 如果checkpoint存在则加载断点之前的训练模型
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
  • 如果用别人完整的checkpoint文件,但自己没有搭建网络结构的代码,可以通过.meta文件来重构网络,并检查权重数据:
# ref: http://blog.csdn.net/spylyt/article/details/71601174
import tensorflow as tf 

def restore_from_meta(sess, graph, ckpt_path):
    with graph.as_default():
        ckpt = tf.train.get_checkpoint_state(ckpt_path)
        if ckpt and ckpt.model_checkpoint_path:
            print('Found checkpoint, try to restore...')
            saver = tf.train.import_meta_graph(''.join([ckpt.model_checkpoint_path, '.meta']))
            saver.restore(sess, ckpt.model_checkpoint_path)

        #### Check
        # 打印出网络中可训练的权重参数名
        for var in tf.trainable_variables():
            print(var)
        # 根据变量名称读取所需变量权重,打印其形状
        conv1_2_weights = graph.get_tensor_by_name('conv1_2/weights:0')
        print(conv1_2_weights.shape)

if __name__ == '__main__':
    ckpt_path = 'data\\VOC2007_images\\GAP_backup\\'
    graph = tf.Graph()
    sess = tf.Session(graph=graph)
    restore_from_meta(sess, graph, ckpt_path)

2)用上述f.train.Saver().restore(sess, ckpt_path)方法来加载模型会将所有权重全部加载,如果你希望加载指定几层权重(比如在做transfer learning的时候),可以通过下面方法实现:

# ref: http://blog.csdn.net/qq_25737169/article/details/78125061
sess = tf.Session()
var = tf.global_variables()
var_to_restore = [val  for val in var if 'conv1' in val.name or 'conv2'in val.name]
saver = tf.train.Saver(var_to_restore )
saver.restore(sess, os.path.join(model_dir, model_name))
var_to_init = [val  for val in var if 'conv1' not in val.name or 'conv2'not in val.name]
tf.initialize_variables(var_to_init)

这样,就值加载了conv1conv2的权重,而其他层中权重则通过网络初始化的方式获得。

如果用tensorflowslim来构建网络的话,操作会更加简单:

exclude = ['layer1', 'layer2']
variables_to_restore = slim.get_variables_to_restore(exclude=exclude)
saver = tf.train.Saver(variables_to_restore)
saver.restore(sess, os.path.join(model_dir, model_name))

该操作是restore除了layer1layer2之外的其他权重。


上述两种情形是自己搭建了网络结构图或者虽然自己没有网络结构图,但手头有对应的.meta文件,可以重构网络。如果只是有该网络的预训练模型权重,希望借用该模型中的某些层的权重来初始化自己的网络。这时候面对的场景就不一样了。

一般的模型权重会保存成:

1)checkpoint权重(model_ckpt.data)或者frozen graph(.pb)格式的文件

  • model_ckpt.data权重可通过tensorflow.python.pywrap_tensorflow.NewCheckpointReadertf.train.NewCheckpointReader 获取,二者的功能一样,下面以前者举例:
import os
from tensorflow.python import pywrap_tensorflow

checkpoint_path = os.path.join(model_dir, "model_ckpt.ckpt")
# 从checkpoint中读出数据
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
# reader = tf.train.NewCheckpointReader(checkpoint_path) # 用tf.train中的NewCheckpointReader方法
var_to_shape_map = reader.get_variable_to_shape_map()
# 输出权重tensor名字和值
for key in var_to_shape_map:
    print("tensor_name: ", key)
    print(reader.get_tensor(key))

如果你想用某些层的pretrained权重来初始化你自己的网络,可以在sess中通过下面的操作完成:

with tf.variable_scope('', reuse = True):
        sess.run(tf.get_variable(your_var_name).assign(reader.get_tensor(pretrained_var_name)))
  • .pb权重图可通过下面的方法获取:
import os
import tensorflow as tf

var1_name = 'pool_3/_reshape:0'
var2_name = 'data/content:0'

graph = tf.Graph()
with graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile('model_graph.pb', 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        var1_tensor, var2_tensor = tf.import_graph_def(od_graph_def, return_elements = [var1_name, var2_name])

得到变量1 pool_3/_reshape:0和变量2 data/content:0的值。

2)如果预训练模型是由caffe model转过来的,这时候可能的格式为numpy文件(.npy),它按照字典的方式来存储各层的权重数据。下面举一个例子说明:

import numpy as np
import cPickle

layer_name = "conv1"
with open('model.npy') as f:
    pretrained_weight = cPickle.load(f)
    layer = pretrained_weight[layer_name]
    weights = layer[0]
    biases = layer[1]

获取conv1中的变量,然后将权重和偏置项分别取出,注意权重的顺序。

如果想用conv1 的权重和偏置来初始化自己网络中的conv1层,则可用下面的方法:

with tf.Session() as sess:
    with tf.variable_scope(layer_name, reuse=True):
        for subkey, data in zip('weights', 'biases'), pretrained_weight[layer_name]:
            sess.run(tf.get_variable(subkey).assign(data))

总结

以上是tensorflow中常见的预训练模型操作,在实际操作过程中根据自己的场景需求选择使用。

### 如何查找或下载名为 `traffic model.h5` 的预训练模型 为了找到并使用名为 `traffic model.h5` 的预训练模型,可以按照以下方法操作: #### 1. 查找模型资源 通常情况下,预训练模型可以从公开的机器学习框架库或者社区分享平台获取。例如 TensorFlow Hub、Keras Model Zoo 或者 GitHub 上的相关项目。 如果目标是交通场景识别模型,则可以通过搜索引擎输入关键词如 “pretrained traffic model h5” 来定位可能的开源实现[^4]。一些常见的模型托管站点包括但不限于: - **TensorFlow Hub**: 提供大量经过优化的预训练模型。 - **Hugging Face Models**: 虽然主要面向自然语言处理领域,但也支持计算机视觉任务。 - **GitHub Projects**: 许多开发者会将自己的研究成果发布到此平台上。 对于具体名称为 `traffic model.h5` 的文件,如果没有明确出处说明其来源的话,建议尝试访问上述提到的地方寻找相似功能描述的产品替代品。 #### 2. 下载与加载模型 假设已经找到了合适的 `.h5` 文件链接地址,在本地保存之后可通过 Python 编程环境完成导入工作流程如下所示: ```python import tensorflow as tf # 加载 HDF5 格式的 Keras 模型 model = tf.keras.models.load_model('path_to_your_traffic_model.h5') # 打印模型结构概览 print(model.summary()) ``` 这里需要注意路径替换为自己实际存储位置;另外考虑到不同版本间可能存在兼容性问题所以最好确认所使用Tensorflow/Keras 版本号是否匹配官方文档推荐配置[^5]。 #### 3. 使用模型进行预测 一旦成功加载了该模型实例后就可以调用它来进行新的数据集上的推理计算啦! ```python from PIL import Image import numpy as np def preprocess_image(image_path): """图像前处理函数""" img = Image.open(image_path).resize((224, 224)) # 假设输入尺寸固定为224x224像素大小 arr = np.array(img)/255. return np.expand_dims(arr,axis=0) test_img = './example.jpg' input_data = preprocess_image(test_img) predictions = model.predict(input_data) print(predictions.argmax()) # 输出类别索引编号 ``` 以上脚本定义了一个简单的辅助方法用于读取图片并将之转换成适合喂给神经网络的形式再执行分类判断动作[^6]。 ---
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值