#include "train_classifier.h"
#include <fstream>
#define DEBUG_OUT 1
#define ERROR 1
//Set up default sample counters and HOG descriptor parameters.
//Matrix pointers start as NULL so that cvReleaseMat (which accepts NULL)
//is safe on them before any training has run.
TrainClassifier::TrainClassifier()
{
    //Number of positive / negative samples (filled in by generateClassifier):
    _positive_sample_num = 0;
    _negative_sample_num = 0;
    //Training matrices are allocated per generateClassifier call; they were
    //previously left uninitialized here.
    _sample_feature_mat = NULL;
    _sample_label_mat = NULL;
    _hog_descriptor_length = 0;
    //HOG descriptor parameters:
    _nbins = 9;                  //number of gradient orientation bins
    _block_size = Size(16, 16);  //block size in pixels
    _block_stride = Size(8, 8);  //block stride in pixels
    _cell_size = Size(8, 8);     //cell size in pixels (histogram unit)
}
//Length of the HOG feature vector for a window of image_size using the
//configured bins / block / stride / cell parameters:
//  nbins * (cells per block) * (block positions along x) * (block positions along y)
//Returns 0 when image_size is smaller than a single block.
int TrainClassifier::getHogDescriptorLength(Size image_size)
{
    //No block fits: without this guard, integer truncation would report one
    //block position, or a negative count that was previously wrapped through
    //size_t arithmetic.
    if (image_size.width < _block_size.width || image_size.height < _block_size.height)
    {
        return 0;
    }
    const int cells_per_block = (_block_size.width / _cell_size.width) *
                                (_block_size.height / _cell_size.height);
    const int blocks_x = (image_size.width - _block_size.width) / _block_stride.width + 1;
    const int blocks_y = (image_size.height - _block_size.height) / _block_stride.height + 1;
    return _nbins * cells_per_block * blocks_x * blocks_y;
}
//Build HOG features for every image listed in the positive/negative list
//files, train a linear SVM on them and save it to save_classifier_path.
//  save_classifier_path : output file passed to CvSVM::save
//  positive_sample_path : text file, one image path per line (label +1)
//  negative_sample_path : text file, one image path per line (label -1)
//  image_size           : every image is resized to this size before HOG
//Returns 0 on success, 1 if there is nothing to train on, a list file can't
//be read, or generator() reports failure.
int TrainClassifier::generateClassifier(string save_classifier_path, string positive_sample_path,string negative_sample_path, Size image_size)
{
    _hog_descriptor_length = getHogDescriptorLength(image_size);
    _image_size = image_size;
    //Sample counts (one per line of each list file):
    _positive_sample_num = getSampleNum(positive_sample_path);
    _negative_sample_num = getSampleNum(negative_sample_path);
    const int total_sample_num = _positive_sample_num + _negative_sample_num;
    //cvCreateMat cannot build an empty training matrix; bail out early instead.
    if (total_sample_num <= 0 || _hog_descriptor_length <= 0)
    {
#if DEBUG_OUT && ERROR
        cout<<"no samples to train!"<<endl;
#endif //DEBUG_OUT && ERROR
        return 1;
    }
    //One feature row per sample. For 64x128 samples this is total x 3780;
    //for 64x64 samples it is total x 1764.
    _sample_feature_mat = cvCreateMat(total_sample_num, _hog_descriptor_length, CV_32FC1);
    cvSetZero(_sample_feature_mat);
    //Sample labels (+1 / -1); rows never filled by train() stay 0.
    _sample_label_mat = cvCreateMat(total_sample_num, 1, CV_32FC1);
    cvSetZero(_sample_label_mat);
    //Extract features for positive then negative samples. A list that can't
    //be opened is an error rather than something to silently train around.
    if (train(positive_sample_path, 1) != 0 || train(negative_sample_path, -1) != 0)
    {
        cvReleaseMat(&_sample_feature_mat);
        cvReleaseMat(&_sample_label_mat);
        return 1;
    }
    //Train, save and release the matrices:
    return generator(save_classifier_path);
}
//获取样本数:
//Count the entries of a sample-list file, i.e. the number of lines
//(train() reads each line as one image path).
//Returns 0 when the file cannot be opened.
int TrainClassifier::getSampleNum(string sample_path)
{
    ifstream in_f(sample_path.c_str(), ios_base::in);
    if (!in_f.is_open())
    {
#if DEBUG_OUT && ERROR
        cout<<"can't find file!"<<endl;
#endif //DEBUG_OUT && ERROR
        //The result is used as a sample count. Returning 1 here (as before)
        //allocated a phantom all-zero row whose label 0 was then trained as
        //an extra class.
        return 0;
    }
    int sample_num = 0;
    string line;
    while (getline(in_f, line))
    {
        ++sample_num;
    }
    //in_f is closed by its destructor.
    return sample_num;
}
//训练正负样本:
//Compute HOG features for every readable image listed in sample_path and
//append them, with the given label, to the training matrices.
//  sample_path : text file, one image path per line
//  label       : class label stored for each sample (callers pass +1 / -1);
//                must be non-zero, because 0 marks an unused row
//Assumes _sample_feature_mat / _sample_label_mat were allocated and zeroed
//(generateClassifier does this). Images that fail to load are skipped.
//Returns 1 if the list file can't be opened, 0 otherwise.
int TrainClassifier::train(string sample_path, int label)
{
    ifstream in_f(sample_path.c_str(), ios_base::in);
    if (!in_f.is_open())
    {
#if DEBUG_OUT && ERROR
        cout<<"can't find file!"<<endl;
#endif //DEBUG_OUT && ERROR
        return 1;
    }
    const int rows = _sample_feature_mat->rows;
    const int cols = _sample_feature_mat->cols;
    //Next free row = first row whose label is still 0. Rows are filled densely
    //from the top and always receive a non-zero label. A function-static
    //counter used to be kept here instead; it survived across
    //generateClassifier calls and wrote past the end of freshly allocated
    //matrices.
    int row = 0;
    while (row < rows && CV_MAT_ELEM(*_sample_label_mat, float, row, 0) != 0.0f)
    {
        ++row;
    }
    //The descriptor configuration does not change per image.
    cv::HOGDescriptor hog(_image_size, _block_size, _block_stride, _cell_size, _nbins);
    vector<float> feature_vec;
    string image_path;
    //Stop at the end of the matrix even if the list grew since it was counted.
    while (row < rows && getline(in_f, image_path))
    {
        //Tolerate CRLF list files: a trailing '\r' would make imread fail.
        if (!image_path.empty() && image_path[image_path.size() - 1] == '\r')
        {
            image_path.erase(image_path.size() - 1);
        }
        cv::Mat image = imread(image_path);
        if (image.empty())
        {
            continue;
        }
        cv::resize(image, image, _image_size, 0, 0, CV_INTER_LINEAR);
        hog.compute(image, feature_vec, _cell_size);
        //Never write past the row, even if the descriptor is longer than expected.
        const int feature_len = (int)feature_vec.size() < cols ? (int)feature_vec.size() : cols;
        for (int j = 0; j < feature_len; j++)
        {
            CV_MAT_ELEM(*_sample_feature_mat, float, row, j) = feature_vec[j];
        }
        CV_MAT_ELEM(*_sample_label_mat, float, row, 0) = (float)label;
        ++row;
    }
    return 0;
}
//产生分类器:
//Train a linear C-SVM on the labelled rows of the training matrices, save it
//to save_classifier_path, and release both matrices (always, even on failure).
//Returns 0 on success, 1 if there are no labelled samples or training fails.
int TrainClassifier::generator(string save_classifier_path)
{
    if (_sample_feature_mat == NULL || _sample_label_mat == NULL)
    {
        return 1;
    }
    const int total_rows = _sample_feature_mat->rows;
    const int cols = _sample_feature_mat->cols;
    //Rows still labelled 0 were never filled (unreadable image or missing
    //list). Passing them to C_SVC would train a spurious third class.
    int valid_rows = 0;
    for (int r = 0; r < total_rows; r++)
    {
        if (CV_MAT_ELEM(*_sample_label_mat, float, r, 0) != 0.0f)
        {
            valid_rows++;
        }
    }
    int ret = 1;
    if (valid_rows > 0)
    {
        CvMat* features = _sample_feature_mat;
        CvMat* labels = _sample_label_mat;
        if (valid_rows < total_rows)
        {
            //Compact the labelled rows into temporary matrices.
            features = cvCreateMat(valid_rows, cols, CV_32FC1);
            labels = cvCreateMat(valid_rows, 1, CV_32FC1);
            int dst = 0;
            for (int r = 0; r < total_rows; r++)
            {
                const float sample_label = CV_MAT_ELEM(*_sample_label_mat, float, r, 0);
                if (sample_label == 0.0f)
                {
                    continue;
                }
                for (int c = 0; c < cols; c++)
                {
                    CV_MAT_ELEM(*features, float, dst, c) = CV_MAT_ELEM(*_sample_feature_mat, float, r, c);
                }
                CV_MAT_ELEM(*labels, float, dst, 0) = sample_label;
                dst++;
            }
        }
        CvSVMParams params;
        params.svm_type = CvSVM::C_SVC;
        params.kernel_type = CvSVM::LINEAR;
        params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER, 1000, FLT_EPSILON);
        params.C = 0.01;
        CvSVM svm;
        //Train a linear SVM classifier:
        if (svm.train(features, labels, NULL, NULL, params))
        {
            svm.save(save_classifier_path.c_str(), 0);
#if DEBUG_OUT
            int supportVectorSize = svm.get_support_vector_count();
            cout<<"support num :"<<supportVectorSize<<endl;
#endif //DEBUG_OUT
            ret = 0;
        }
        if (features != _sample_feature_mat)
        {
            cvReleaseMat(&features);
            cvReleaseMat(&labels);
        }
    }
    else
    {
#if DEBUG_OUT && ERROR
        cout<<"no valid samples!"<<endl;
#endif //DEBUG_OUT && ERROR
    }
    cvReleaseMat(&_sample_feature_mat);
    cvReleaseMat(&_sample_label_mat);
    return ret;
}