个人笔记 感谢指正
1.导入包
import tensorflow as tf
import os
import re
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
2.定义一个处理文件的类
class NodeLookup(object):
def __init__(self):
label_lookup_path = 'inception_model/imagenet_2012_challenge_label_map_proto.pbtxt'
uid_lookup_path = 'inception_model/imagenet_synset_to_human_label_map.txt'
self.node_lookup = self.load(label_lookup_path,uid_lookup_path)
def load(self,label_lookup_path,uid_lookup_path):
#加载分类字符串n**********对应分类名称的文件
proto_as_ascii_lines = tf.gfile.GFile(uid_lookup_path).readlines() #总结一下tf.gfile.GFile()函数
#proto_as_ascii_lines是一个列表,是一个可迭代对象
uid_to_human = {}
#一行一行的读取数据
for line in proto_as_ascii_lines:
#去掉换行符
line = line.strip('\n')
parsed_items = line.split('\t')
#获取分类编号
uid = parsed_items[0]
#获取分类名称
human_string = parsed_items[1]
#用编号和分类名称新建一个字典
uid_to_human[uid] = human_string
#加载分类字符串n********对应的分类编号1-1000
proto_as_ascii = tf.gfile.GFile(label_lookup_path).readlines()
node_id_to_uid = {}
for line in proto_as_ascii:
if line.startswith(' target_class:'):
#获取分类编号
target_class = int(line.split(': ')[1])
if line.startswith(' target_class_string:'):
target_class_string = line.split(': ')[1]
#保存分类编号1-1000与编号字符串n**********映射关系
node_id_to_uid[target_class] = target_class_string[1:-2]
#建立分类编号1-1000与分类名称的关系
node_id_to_name = {}
for key,val in node_id_to_uid.items():
#获取分类名称
name = uid_to_human[val]
#建立分类编号1-1000到分类名称的映射
node_id_to_name[key] = name
return node_id_to_name
#传入分类编号,返回分类名称
def id_to_string(self,node_id):
if node_id not in self.node_lookup: #只比较键值吗?是的
return ''
return self.node_lookup[node_id]
#tf.gfile.GFile()函数:获取文本操作句柄,类似于python提供的文本操作open()函数,filename是要打开的文件名,mode是以何种方式去读写,将会返回一个文本操作句柄。
#readlines()函数:Returns all lines from the file in a list.
#readline()一样返回的是字符串,但是每次只读一行:Reads the next line from the file. Leaves the ‘\n’ at the end.
#read():函数是将文件读成一个字符串,返回一个字符串,是一次性读完
#split()函数:Return a list of the sections in the bytes, using sep as the delimiter.
#startsWith()函数:方法用来判断当前字符串是否是以另外一个给定的子字符串“开头”的,根据判断结果返回 true 或 false。
参数:str.startsWith(searchString [, position]);
searchString : 要搜索的子字符串。
position : 在 str 中搜索 searchString 的开始位置,默认值为 0,也就是真正的字符串开头处。
#if node_id not in self.note_lookup: #只比较键值吗?是的
3.创建一个图来存放google训练好的模型
with tf.gfile.FastGFile('inception_model/classify_image_graph_def.pb','rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def,name='')
4.编写会话
with tf.Session() as sess:
softmax_tensor = sess.graph.get_tensor_by_name('softmax:0') #获取计算节点或tensor
#sess.graph:The graph that was launched in this session
#遍历目录
for root,dirs,files in os.walk('images/'): #总结一下os.walk
for file in files:
#载入图片
image_data = tf.gfile.FastGFile(os.path.join(root,file),'rb').read()
predictions = sess.run(softmax_tensor,{'DecodeJpeg/contents:0':image_data}) #图片是Jpg格式
predictions = np.squeeze(predictions) #返回的是每个类的概率
#打印图片路径及名称
image_path = os.path.join(root,file)
print(image_path)
#显示图片
img = Image.open(image_path) #只有打开文件才可以对文件操作!!!很多包都有打开文件的API
plt.imshow(img)
plt.axis('off')
plt.show()
#排序
top_k = predictions.argsort()[-5:][::-1] #argsort()函数返回的是坐标,但这里的坐标正好是类的node_id
node_lookup = NodeLookup()
for node_id in top_k:
#获取分类名称
human_string = node_lookup.id_to_string(node_id)
#获取该分量的置信度
score = predictions[node_id]
print('%s (score = %.5f)' % (human_string,score))
print() #换行的作用
#os.walk()函数:
top -- 是你所要遍历的目录的地址.
topdown -- 可选,为 True,则优先遍历top目录,
否则优先遍历 top 的子目录(默认为开启)。
onerror -- 可选,需要一个 callable 对象,当 walk 需要异常时,会调用。
followlinks -- 可选,如果为 True,则会遍历目录下的快捷方式,默认开启
return None
该函数没有返回值会使用yield关键字抛出一个存放当前
该层目录(root,dirs,files)的三元组,最终将所有目录层的的结果变为一个生成器
root 所指的是当前正在遍历的这个文件夹的本身的地址
dirs 是一个 list ,内容是该文件夹中所有的目录的名字(不包括子目录)
files 同样是 list , 内容是该文件夹中所有的文件(不包括子目录)
注:返回值有多个这样的元组(root,dirs,files),例:
print(list(os.walk('/mnt',topdown = False)))
输出:
[('/mnt', ['dira', 'dirb'], ['file1', 'file2']), ('/mnt/dira', [], ['file_test_A']), ('/mnt/dirb', [], ['file_test_B'])]
参考来源:https://blog.csdn.net/m0_37717595/article/details/80359272
#tf.graph.get_tensor_by_name(‘softmax:0’)函数:Returns the object referred to by obj
, as an Operation
or `Tensor
注:'softmax:0’中0的意思是这个张量是计算节点上的第几个结果
"""Returns the `Tensor` with the given `name`.
This method may be called concurrently from multiple threads.
Args:
name: The name of the `Tensor` to return.
Returns:
The `Tensor` with the given `name`.