当前位置:文档之家› K近邻算法(KNN)的C++实现

K近邻算法(KNN)的C++实现

本文不对KNN算法做过多的理论上的解释,主要是针对问题,进行算法的设计和代码的注解。

KNN算法:优点:精度高、对异常值不敏感、无数据输入假定。

缺点:计算复杂度高、空间复杂度高。

适用数据范围:数值型和标称性。

工作原理:存在一个样本数据集合,也称作训练样本集,并且样本集中每个数据都存在标签,即我们知道样本集中每一个数据与所属分类的对应关系。

输入没有标签的新数据后,将新数据的每个特征与样本集中数据对应的特征进行比较,然后算法提取样本集中特征最相似数据(最近邻)的分类标签。

一般来说,我们只选择样本数据及中前k个最相似的数据,这就是k-近邻算法中k的出处,通常k选择不大于20的整数。

最后,选择k个最相似数据中出现次数最多的分类,作为新数据的分类。

K-近邻算法的一般流程:(1)收集数据:可以使用任何方法(2)准备数据:距离计算所需要的数值,最好是结构化的数据格式(3)分析数据:可以使用任何方法(4)训练算法:此步骤不适用k-邻近算法(5)测试算法:计算错误率(6)使用算法:首先需要输入样本数据和结构化的输出结果,然后运行k-近邻算法判定输入数据分别属于哪个分类,最后应用对计算出的分类执行后续的处理。

问题一:现在我们假设一个场景,就是要为左边上的点进行分类,如下图所示:上图一共12个左边点,每个坐标点都有相应的坐标(x,y)以及它所属的类别A/B,那么现在需要做的就是给定一个点坐标(x1,y1),判断它属于的类别A或者B。

所有的坐标点在data.txt文件中:0.0 1.1 A1.0 1.0 A2.0 1.0 B0.5 0.5 A2.5 0.5 B0.0 0.0 A1.0 0.0 A2.0 0.0 B3.0 0.0 B0.0 -1.0 A1.0 -1.0 A2.0 -1.0 Bstep1:通过类的默认构造函数去初始化训练数据集dataSet和测试数据testData。

step2:用get_distance()来计算测试数据testData和每一个训练数据dataSet[index]的距离,用map_index_dis来保存键值对<index,distance>,其中index代表第几个训练数据,distance代表第index个训练数据和测试数据的距离。

step3:将map_index_dis按照value值(即distance值)从小到大的顺序排序,然后取前k个最小的value值,用map_label_freq来记录每一个类标签出现的频率。

step4:遍历map_label_freq中的value值,返回value最大的那个key 值,就是测试数据属于的类。

看一下代码KNN_:#include<iostream>#include<map>#include<vector>#include<stdio.h>#include<cmath>#include<cstdlib>#include<algorithm>#include<fstream>using namespace std;typedef char tLabel;typedef double tData;typedef pair<int,double> PAIR;constintcolLen = 2;constintrowLen = 12;ifstream fin;ofstreamfout;class KNN{private:tDatadataSet[rowLen][colLen];tLabel labels[rowLen];tDatatestData[colLen];int k;map<int,double>map_index_dis;map<tLabel,int>map_label_freq;doubleget_distance(tData *d1,tData *d2);public:KNN(int k);voidget_all_distance();voidget_max_freq_label();structCmpByValue{bool operator() (const PAIR&lhs,const PAIR&rhs){returnlhs.second<rhs.second;}};};KNN::KNN(int k){this->k = k;fin.open("data.txt");if(!fin){cout<<"can not open the file data.txt"<<endl;exit(1);}/* input the dataSet */for(inti=0;i<rowLen;i++){for(int j=0;j<colLen;j++){fin>>dataSet[i][j];}fin>>labels[i];}cout<<"please input the test data :"<<endl;/* inuput the test data */for(inti=0;i<colLen;i++)cin>>testData[i];}/** calculate the distance between test data and dataSet[i]*/double KNN:: get_distance(tData *d1,tData *d2){double sum = 0;for(inti=0;i<colLen;i++){sum += pow( (d1[i]-d2[i]) , 2 );}// cout<<"the sum is = "<<sum<<endl;returnsqrt(sum);}/** calculate all the distance between test data and each training data*/void KNN:: get_all_distance(){double distance;inti;for(i=0;i<rowLen;i++){distance = get_distance(dataSet[i],testData);//<key,value> =><i,distance>map_index_dis[i] = distance;}//traverse the map to print the index and distancemap<int,double>::const_iterator it = map_index_dis.begin();while(it!=map_index_dis.end()){cout<<"index = "<<it->first<<" distance = "<<it->second<<endl;it++;}}/** check which label the test data belongs to to classify the test data*/void KNN:: get_max_freq_label(){//transform the map_index_dis to vec_index_disvector<PAIR>vec_index_dis( map_index_dis.begin(),map_index_dis.end() );//sort the vec_index_dis by distance from low to high to get the nearest datasort(vec_index_dis.begin(),vec_index_dis.end(),CmpByValue());for(inti=0;i<k;i++){cout<<"the index = "<<vec_index_dis[i].first<<" the distance ="<<vec_index_dis[i].second<<" the label = "<<labels[vec_index_dis[i].first]<<" the coordinate ( "<<dataSet[ vec_index_dis[i].first ][0]<<","<<dataSet[ vec_index_dis[i].first ][1]<<" )"<< endl;//calculate the count of each labelmap_label_freq[ labels[ vec_index_dis[i].first ] ]++;}map<tLabel,int>::const_iteratormap_it = map_label_freq.begin();tLabel label;intmax_freq = 0;//find the most frequent labelwhile(map_it != map_label_freq.end() ){if(map_it->second >max_freq ){max_freq = map_it->second;label = map_it->first;}map_it++;}cout<<"The test data belongs to the "<<label<<" label"<<endl;}int main(){int k ;cout<<"please input the k value : "<<endl;cin>>k;KNN knn(k);knn.get_all_distance();knn.get_max_freq_label();system("pause");return 0;}我们来测试一下这个分类器(k=5):testData(5.0,5.0):testData(-5.0,-5.0):testData(1.6,0.5):分类结果的正确性可以通过坐标系来判断,可以看出结果都是正确的。

相关主题