这段代码是 DetectionOutput 层的实现,说明了如何由 PriorBox、loc、conf 三个层的输出得到最终的检测框。
源码如下:
detection_output_layer.cpp
#include <algorithm>
#include <fstream> // NOLINT(readability/streams)
#include <map>
#include <string>
#include <utility>
#include <vector>
#include "boost/filesystem.hpp"
#include "boost/foreach.hpp"
#include "caffe/layers/detection_output_layer.hpp"
namespace caffe {
// The bottoms of the DetectionOutput layer are, in order: loc, conf, prior.
//
// LayerSetUp reads the layer configuration (detection_output_param) from the
// prototxt and validates it. It sets up the optional result-saving and
// visualization state, and shapes internal blobs like the loc/conf bottoms.
template <typename Dtype>
void DetectionOutputLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top) {
  const DetectionOutputParameter& detection_output_param =
      this->layer_param_.detection_output_param();

  // The number of classes must be given explicitly.
  CHECK(detection_output_param.has_num_classes()) << "Must specify num_classes";
  num_classes_ = detection_output_param.num_classes();
  // With shared locations there is one set of box predictions for all
  // classes; otherwise there is one set per class.
  share_location_ = detection_output_param.share_location();
  num_loc_classes_ = share_location_ ? 1 : num_classes_;
  // Label id of the background class.
  background_label_id_ = detection_output_param.background_label_id();
  code_type_ = detection_output_param.code_type();
  variance_encoded_in_target_ =
      detection_output_param.variance_encoded_in_target();
  keep_top_k_ = detection_output_param.keep_top_k();
  // Confidence threshold. When it is unset, -FLT_MAX is used, which
  // effectively means there is no confidence cut-off.
  confidence_threshold_ = detection_output_param.has_confidence_threshold() ?
      detection_output_param.confidence_threshold() : -FLT_MAX;

  // Parameters used in non-maximum suppression (NMS).
  nms_threshold_ = detection_output_param.nms_param().nms_threshold();
  CHECK_GE(nms_threshold_, 0.) << "nms_threshold must be non negative.";
  eta_ = detection_output_param.nms_param().eta();
  CHECK_GT(eta_, 0.);
  CHECK_LE(eta_, 1.);
  // top_k_ stays -1 when nms_param.top_k is not set.
  top_k_ = -1;
  if (detection_output_param.nms_param().has_top_k()) {
    top_k_ = detection_output_param.nms_param().top_k();
  }

  // Optional saving of detection results to files.
  const SaveOutputParameter& save_output_param =
      detection_output_param.save_output_param();
  output_directory_ = save_output_param.output_directory();
  if (!output_directory_.empty()) {
    // Start from a clean directory. An existing directory is removed
    // recursively, including any unrelated files in it, and then recreated.
    if (boost::filesystem::is_directory(output_directory_)) {
      boost::filesystem::remove_all(output_directory_);
    }
    if (!boost::filesystem::create_directories(output_directory_)) {
      LOG(WARNING) << "Failed to create directory: " << output_directory_;
    }
  }
  output_name_prefix_ = save_output_param.output_name_prefix();
  // Saving is requested by a non-empty output directory. The checks below
  // turn it off again if a required input is missing.
  need_save_ = !output_directory_.empty();
  output_format_ = save_output_param.output_format();

  // A label map file is required to map label ids to (display) names.
  if (save_output_param.has_label_map_file()) {
    string label_map_file = save_output_param.label_map_file();
    if (label_map_file.empty()) {
      // Ignore saving if there is no label_map_file provided.
      LOG(WARNING) << "Provide label_map_file if output results to files.";
      need_save_ = false;
    } else {
      LabelMap label_map;
      CHECK(ReadProtoFromTextFile(label_map_file, &label_map))
          << "Failed to read label map file: " << label_map_file;
      CHECK(MapLabelToName(label_map, true, &label_to_name_))
          << "Failed to convert label to name.";
      CHECK(MapLabelToDisplayName(label_map, true, &label_to_display_name_))
          << "Failed to convert label to display name.";
    }
  } else {
    need_save_ = false;
  }

  // A name/size file lists the test images (name, height, width).
  if (save_output_param.has_name_size_file()) {
    string name_size_file = save_output_param.name_size_file();
    if (name_size_file.empty()) {
      // Ignore saving if there is no name_size_file provided.
      LOG(WARNING) << "Provide name_size_file if output results to files.";
      need_save_ = false;
    } else {
      std::ifstream infile(name_size_file.c_str());
      CHECK(infile.good())
          << "Failed to open name size file: " << name_size_file;
      // The file is in the following format:
      //    name height width
      //    ...
      string name;
      int height = 0;
      int width = 0;
      while (infile >> name >> height >> width) {
        names_.push_back(name);
        sizes_.push_back(std::make_pair(height, width));
      }
      infile.close();
      if (save_output_param.has_num_test_image()) {
        num_test_image_ = save_output_param.num_test_image();
      } else {
        num_test_image_ = names_.size();
      }
      CHECK_LE(num_test_image_, names_.size());
      // Reshape() computes name_count_ % num_test_image_ while saving.
      // Zero test images would make that a division by zero, so disable
      // saving instead.
      if (num_test_image_ <= 0) {
        LOG(WARNING) << "No test images available (empty name_size_file or "
                     << "num_test_image == 0); results will not be saved.";
        need_save_ = false;
      }
    }
  } else {
    need_save_ = false;
  }

  // Optional resize parameters, taken from save_output_param.
  has_resize_ = save_output_param.has_resize_param();
  if (has_resize_) {
    resize_param_ = save_output_param.resize_param();
  }
  name_count_ = 0;

  // Optional visualization of detections.
  visualize_ = detection_output_param.visualize();
  if (visualize_) {
    // The visualization threshold defaults to 0.6 unless set in the prototxt.
    visualize_threshold_ = 0.6;
    if (detection_output_param.has_visualize_threshold()) {
      visualize_threshold_ = detection_output_param.visualize_threshold();
    }
    data_transformer_.reset(
        new DataTransformer<Dtype>(this->layer_param_.transform_param(),
                                   this->phase_));
    data_transformer_->InitRand();
    save_file_ = detection_output_param.save_file();
  }

  // Shape internal blobs like the loc (bottom[0]) and conf (bottom[1])
  // inputs. bbox_permute_ is only shaped when locations are not shared.
  bbox_preds_.ReshapeLike(*(bottom[0]));
  if (!share_location_) {
    bbox_permute_.ReshapeLike(*(bottom[0]));
  }
  conf_permute_.ReshapeLike(*(bottom[1]));
}
template <typename Dtype>
void DetectionOutputLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
if (need_save_) {
CHECK_LE(name_count_, names_.size());
if (name_count_ % num_test_image_ == 0) {
// Clean all outputs.
if (output_format_ == "VOC") {
boost::filesystem::path output_directory(output_directory_);
for (map<int, string>::iterator it = label_to_name_.begin();
it != label_to_name_.end(); ++it) {
if (it->first == background_label_id_) {
continue;
}
std::ofstream outfile;
boost::filesystem::path file(
output_name_prefix_ + it->second + ".txt");
boost::filesystem::path out_file = output_directory / file;
outfile.open(out_file.string().c_str(), std::ofstream::out);
}
}
}
}
// 这里的reshape挺重要的,注意各个数的涵义
CHECK_EQ(bottom[0]->num(), bottom[1]->num());
if (bbox_preds_.num() != bottom[0]->num() ||
bbox_preds_.count(