利用类激活映射和t-SNE探索网络预测与行为
立即解锁
发布时间: 2025-09-06 00:17:03 阅读量: 1 订阅数: 55 AIGC 

### 利用类激活映射和 t-SNE 探索网络预测与行为
#### 1. 利用类激活映射研究网络预测
在图像分类任务中,深度学习网络常被视为“黑盒”,我们难以知晓网络学到了什么,也不清楚输入的哪部分对预测结果负责。当模型给出错误预测时,往往没有任何预警或解释。类激活映射(Class Activation Mapping,CAM)是一种可以直观解释卷积神经网络预测结果的技术。
##### 1.1 类激活映射的作用
- **解释错误预测**:通过类激活映射,我们可以检查输入图像的特定部分是否“迷惑”了网络,导致其做出错误预测。例如,在训练区分猫和狗的网络时,网络在训练集上准确率很高,但在实际应用中表现不佳。使用类激活映射分析训练样本后,发现网络的预测并非基于图像中的猫和狗,而是背景。原来训练集中所有猫的图片背景都是红色,狗的图片背景都是绿色,网络在训练时学习到的是背景颜色。此时,我们可以收集无此偏差的新数据,提高网络的鲁棒性。
- **识别训练集偏差**:有助于发现训练集中的偏差,进而提高模型的准确性。
##### 1.2 具体操作步骤
- **加载预训练网络和摄像头**
- 选择合适的预训练卷积神经网络,如 SqueezeNet、GoogLeNet、ResNet - 18 和 MobileNet - v2 等相对较快的网络。注意,不能对网络末尾有多个全连接层的网络(如 AlexNet、VGG - 16 和 VGG - 19)使用类激活映射。
```matlab
netName = ;
net = eval(netName);
camera = webcam;
inputSize = net.Layers(1).InputSize(1:2);
classes = net.Layers(end).Classes;
layerName = activationLayerName(netName);
```
- **显示类激活映射**
- 创建一个图形窗口,并在循环中执行类激活映射。循环会持续运行,直到关闭图形窗口。
```matlab
h = figure('Units','normalized','Position',[0.05 0.05 0.9 0.8],'Visible','on');
while ishandle(h)
im = snapshot(camera);
imResized = imresize(im,[inputSize(1), NaN]);
imageActivations = activations(net,imResized,layerName);
scores = squeeze(mean(imageActivations,[1 2]));
if netName ~= "squeezenet"
fcWeights = net.Layers(end-2).Weights;
fcBias = net.Layers(end-2).Bias;
scores = fcWeights*scores + fcBias;
[~,classIds] = maxk(scores,3);
weightVector = shiftdim(fcWeights(classIds(1),:),-1);
classActivationMap = sum(imageActivations.*weightVector,3);
else
[~,classIds] = maxk(scores,3);
classActivationMap = imageActivations(:,:,classIds(1));
end
scores = exp(scores)/sum(exp(scores));
maxScores = scores(classIds);
labels = classes(classIds);
subplot(1,2,1)
imshow(im)
subplot(1,2,2)
CAMshow(im,classActivationMap)
title(string(labels) + ", " + string(maxScores));
drawnow
end
clear camera
```
- **示例映射分析**
- **正确分类示例**:网络将一张图片中的物体正确识别为乐福鞋(一种鞋子)。类激活映射显示,输入图像的整个鞋子部分都对预测结果有贡献,其中鞋尖和鞋口的红色区域贡献最大。
- **错误分类示例**:
- 网络将一张图片分类为鼠标,类激活映射显示,预测不仅基于图片中的鼠标,还包括键盘。这可能是因为训练集中有很多鼠标在键盘旁边的图片,导致网络认为包含键盘的图片更可能包含鼠标。
- 网络将一张咖啡杯的图片误分类为搭扣,类激活映射显示,网络误分类是因为图片中存在过多干扰物体,网络检测并聚焦在了手表腕带上,而不是咖啡杯。
##### 1.3 辅助函数
- **CAMshow 函数**:将类激活映射覆盖在图像的灰度变暗版本上。
```matlab
function CAMshow(im,CAM)
imSize = size(im);
CAM = imresize(CAM,imSize(1:2));
CAM = normalizeImage(CAM);
CAM(CAM<0.2) = 0;
cmap = jet(255).*linspace(0,1,255)';
CAM = ind2rgb(uint8(CAM*255),cmap)*255;
combinedImage = double(rgb2gray(im))/2 + CAM;
combinedImage = normalizeImage(combinedImage)*255;
imshow(uint8(combinedImage));
end
```
- **normalizeImage 函数**:对图像进行归一化处理。
```matlab
function N = normalizeImage(I)
minimum = min(I(:));
maximum = max(I(:));
N = (I-minimum)/(maximum-minimum);
end
```
- **activationLayerName 函数**:返回提取激活值的层的名称。
```matlab
function layerName = activationLayerName(netName)
if netName == "squeezenet"
layerName = 'relu_conv10';
elseif netName == "googlenet"
layerName = 'inception_5b-output';
elseif netName == "resnet18"
layerName = 'res5b_relu';
elseif netName == "mobilenetv2"
layerName = 'out_relu';
end
end
```
#### 2. 使用 t - SNE 查看网络行为
t - SNE(t - distributed stochastic neighbor embedding)是一种将高维数据映射到二维的技术,通过可视化网络激活值,我们可以更好地理解网络的工作原理。
##### 2.1 t - SNE 的作用
- **可视化数据表示变化**:可以可视化深度学习网络在数据通过各层时如何改变输入数据的表示。
- **发现输入数据问题**:帮助发现输入数据中的问题,理解网络对哪些观测值分类错误。例如,t - SNE 可以将 softmax 层的多维激活值降维为二维表示,结果中的紧密聚类对应网络通常能正确分类的类别。通过可视化,我们可以找到出现在错误聚类中的点,这些点可能是标签错误的观测值,或者是因为与其他类别的观测值相似而被网络误分类。
##### 2.2 具体操作步骤
- **下载数据集**:使用包含 978
0
0
复制全文
相关推荐








