Inception_V3的运行

个人笔记 感谢指正

1.导入包

import tensorflow as tf
import os
import re
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

2.定义一个处理文件的类

class NodeLookup(object):
    def __init__(self):
        label_lookup_path = 'inception_model/imagenet_2012_challenge_label_map_proto.pbtxt'
        uid_lookup_path = 'inception_model/imagenet_synset_to_human_label_map.txt'
        self.node_lookup = self.load(label_lookup_path,uid_lookup_path)

    def load(self,label_lookup_path,uid_lookup_path):
        #加载分类字符串n**********对应分类名称的文件
        proto_as_ascii_lines = tf.gfile.GFile(uid_lookup_path).readlines()  #总结一下tf.gfile.GFile()函数
                                                                            #proto_as_ascii_lines是一个列表,是一个可迭代对象
        uid_to_human = {}
        #一行一行的读取数据
        for line in proto_as_ascii_lines:
            #去掉换行符
            line = line.strip('\n')
            parsed_items = line.split('\t')
            #获取分类编号
            uid = parsed_items[0]
            #获取分类名称
            human_string = parsed_items[1]
            #用编号和分类名称新建一个字典
            uid_to_human[uid] = human_string

        #加载分类字符串n********对应的分类编号1-1000
        proto_as_ascii = tf.gfile.GFile(label_lookup_path).readlines()
        node_id_to_uid = {}
        for line in proto_as_ascii:
            if line.startswith('  target_class:'):
                #获取分类编号
                target_class = int(line.split(': ')[1])
            if line.startswith('  target_class_string:'):
                target_class_string = line.split(': ')[1]
                #保存分类编号1-1000与编号字符串n**********映射关系
                node_id_to_uid[target_class] = target_class_string[1:-2]

        #建立分类编号1-1000与分类名称的关系
        node_id_to_name = {}
        for key,val in node_id_to_uid.items():
            #获取分类名称
            name = uid_to_human[val]
            #建立分类编号1-1000到分类名称的映射
            node_id_to_name[key] = name
        return node_id_to_name

    #传入分类编号,返回分类名称
    def id_to_string(self,node_id):
        if node_id not in self.node_lookup:  #只比较键值吗?是的
            return ''
        return self.node_lookup[node_id]

#tf.gfile.GFile()函数:获取文本操作句柄,类似于python提供的文本操作open()函数,filename是要打开的文件名,mode是以何种方式去读写,将会返回一个文本操作句柄。

#readlines()函数:Returns all lines from the file in a list.
#readline()一样返回的是字符串,但是每次只读一行:Reads the next line from the file. Leaves the ‘\n’ at the end.
#read():函数是将文件读成一个字符串,返回一个字符串,是一次性读完

#split()函数:Return a list of the sections in the bytes, using sep as the delimiter.

#startsWith()函数:方法用来判断当前字符串是否是以另外一个给定的子字符串“开头”的,根据判断结果返回 true 或 false。
参数:str.startsWith(searchString [, position]);
searchString : 要搜索的子字符串。
position : 在 str 中搜索 searchString 的开始位置,默认值为 0,也就是真正的字符串开头处。

#if node_id not in self.note_lookup: #只比较键值吗?是的

3.创建一个图来存放google训练好的模型

with tf.gfile.FastGFile('inception_model/classify_image_graph_def.pb','rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def,name='')

4.编写会话

with tf.Session() as sess:
    softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')  #获取计算节点或tensor
                                                                 #sess.graph:The graph that was launched in this session
    #遍历目录
    for root,dirs,files in os.walk('images/'):  #总结一下os.walk
        for file in files:
            #载入图片
            image_data = tf.gfile.FastGFile(os.path.join(root,file),'rb').read()
            predictions = sess.run(softmax_tensor,{'DecodeJpeg/contents:0':image_data})  #图片是Jpg格式
            predictions = np.squeeze(predictions)  #返回的是每个类的概率

            #打印图片路径及名称
            image_path = os.path.join(root,file)
            print(image_path)

            #显示图片
            img = Image.open(image_path)  #只有打开文件才可以对文件操作!!!很多包都有打开文件的API
            plt.imshow(img)
            plt.axis('off')
            plt.show()

            #排序
            top_k = predictions.argsort()[-5:][::-1]  #argsort()函数返回的是坐标,但这里的坐标正好是类的node_id
            node_lookup = NodeLookup()
            for node_id in top_k:
                #获取分类名称
                human_string = node_lookup.id_to_string(node_id)
                #获取该分量的置信度
                score = predictions[node_id]
                print('%s (score = %.5f)' % (human_string,score))
            print()  #换行的作用

#os.walk()函数:

top 	-- 是你所要遍历的目录的地址.
topdown -- 可选,为 True,则优先遍历top目录,
	   否则优先遍历 top 的子目录(默认为开启)。
onerror -- 可选,需要一个 callable 对象,当 walk 需要异常时,会调用。
followlinks -- 可选,如果为 True,则会遍历目录下的快捷方式,默认开启 
return None
   该函数没有返回值会使用yield关键字抛出一个存放当前
   该层目录(root,dirs,files)的三元组,最终将所有目录层的的结果变为一个生成器
   root  所指的是当前正在遍历的这个文件夹的本身的地址
   dirs  是一个 list ,内容是该文件夹中所有的目录的名字(不包括子目录)
   files 同样是 list , 内容是该文件夹中所有的文件(不包括子目录)	

注:返回值有多个这样的元组(root,dirs,files),例:
在这里插入图片描述

print(list(os.walk('/mnt',topdown = False)))

输出:

[('/mnt', ['dira', 'dirb'], ['file1', 'file2']), ('/mnt/dira', [], ['file_test_A']), ('/mnt/dirb', [], ['file_test_B'])]
参考来源:https://blog.csdn.net/m0_37717595/article/details/80359272

#tf.graph.get_tensor_by_name(‘softmax:0’)函数:Returns the object referred to by obj, as an Operation or `Tensor
注:'softmax:0’中0的意思是这个张量是计算节点上的第几个结果

"""Returns the `Tensor` with the given `name`.

    This method may be called concurrently from multiple threads.

    Args:
      name: The name of the `Tensor` to return.

    Returns:
      The `Tensor` with the given `name`.
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值