博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
libsvm代码阅读:关于Kernel类分析
阅读量:5781 次
发布时间:2019-06-18

本文共 6549 字,大约阅读时间需要 21 分钟。

hot3.png

这一篇博文来分析下Kernel类,代码上很简单,一般都能看懂。Kernel类主要是为SVM的核函数服务的,里面实现了SVM常用的核函数,通过函数指针来使用这些核函数。

其中几个常用核函数如下所示:(一般情况下,使用RBF核函数能取得很好的效果)

关于基类QMatrix在Kernel中的作用并不明显,只是定义了一些纯虚函数,Kernel继承这些函数,Kernel只对swap_index进行了定义。其余的get_Q和get_QD在Kernel并没有用到。

[cpp]  
<EMBED id=ZeroClipboardMovie_1 height=18 name=ZeroClipboardMovie_1 type=application/x-shockwave-flash align=middle pluginspage=http://www.macromedia.com/go/getflashplayer width=18 src=http://static.blog.csdn.net/scripts/ZeroClipboard/ZeroClipboard.swf wmode="transparent" flashvars="id=1&width=18&height=18" allowfullscreen="false" allowscriptaccess="always" bgcolor="#ffffff" quality="best" menu="false" loop="false">
  1. class QMatrix {  
  2. public:  
  3.     virtual Qfloat *get_Q(int column, int len) const = 0;//纯虚函数,在子类中实现,important!  
  4.     virtual double *get_QD() const = 0;  
  5.     virtual void swap_index(int i, int j) const = 0;  
  6.     virtual ~QMatrix() {}  
  7. };  

Kernel类的定义函数,比较简单就不细说。

[cpp]  
<EMBED id=ZeroClipboardMovie_2 height=18 name=ZeroClipboardMovie_2 type=application/x-shockwave-flash align=middle pluginspage=http://www.macromedia.com/go/getflashplayer width=18 src=http://static.blog.csdn.net/scripts/ZeroClipboard/ZeroClipboard.swf wmode="transparent" flashvars="id=2&width=18&height=18" allowfullscreen="false" allowscriptaccess="always" bgcolor="#ffffff" quality="best" menu="false" loop="false">
  1. class Kernel: public QMatrix {  
  2. public:  
  3.     Kernel(int l, svm_node * const * x, const svm_parameter& param);  
  4.     virtual ~Kernel();  
  5.   
  6.     static double k_function(const svm_node *x, const svm_node *y,  
  7.                  const svm_parameter& param);  
  8.     virtual Qfloat *get_Q(int column, int len) const = 0;  
  9.     virtual double *get_QD() const = 0;  
  10.     virtual void swap_index(int i, int j) const // no so const...  
  11.     {  
  12.         swap(x[i],x[j]);  
  13.         if(x_square) swap(x_square[i],x_square[j]);  
  14.     }  
  15. protected:  
  16.   
  17.     double (Kernel::*kernel_function)(int i, int j) const;  
  18.   
  19. private:  
  20.     const svm_node **x;//用来指向样本数据,每次数据传入时通过克隆函数来实现,完全重新分配内存,主要是为处理多类着想  
  21.     double *x_square;//使用RBF 核才使用  
  22.   
  23.     // svm_parameter  
  24.     const int kernel_type;  
  25.     const int degree;  
  26.     const double gamma;  
  27.     const double coef0;  
  28.   
  29.     static double dot(const svm_node *px, const svm_node *py);  
  30.   
  31.     double kernel_linear(int i, int j) const  
  32.     {  
  33.         return dot(x[i],x[j]);  
  34.     }  
  35.     double kernel_poly(int i, int j) const  
  36.     {  
  37.         return powi(gamma*dot(x[i],x[j])+coef0,degree);  
  38.     }  
  39.     double kernel_rbf(int i, int j) const  
  40.     {  
  41.         return exp(-gamma*(x_square[i]+x_square[j]-2*dot(x[i],x[j])));  
  42.     }  
  43.   
  44.     double kernel_sigmoid(int i, int j) const  
  45.     {  
  46.         return tanh(gamma*dot(x[i],x[j])+coef0);  
  47.     }  
  48.     double kernel_precomputed(int i, int j) const  
  49.     {  
  50.         return x[i][(int)(x[j][0].value)].value;  
  51.     }  
  52. };  

这个Kernel类的函数比较清晰,我直接把它的全部代码贴出。

全部代码如下:

[cpp]  
<EMBED id=ZeroClipboardMovie_3 height=18 name=ZeroClipboardMovie_3 type=application/x-shockwave-flash align=middle pluginspage=http://www.macromedia.com/go/getflashplayer width=18 src=http://static.blog.csdn.net/scripts/ZeroClipboard/ZeroClipboard.swf wmode="transparent" flashvars="id=3&width=18&height=18" allowfullscreen="false" allowscriptaccess="always" bgcolor="#ffffff" quality="best" menu="false" loop="false">
  1. //  
  2. // Kernel evaluation  
  3. //  
  4. // the static method k_function is for doing single kernel evaluation  
  5. // the constructor of Kernel prepares to calculate the l*l kernel matrix  
  6. // the member function get_Q is for getting one column from the Q Matrix  
  7. //  
  8. class QMatrix {  
  9. public:  
  10.     virtual Qfloat *get_Q(int column, int len) const = 0;  
  11.     virtual double *get_QD() const = 0;  
  12.     virtual void swap_index(int i, int j) const = 0;  
  13.     virtual ~QMatrix() {}  
  14. };  
  15.   
  16. class Kernel: public QMatrix {  
  17. public:  
  18.     Kernel(int l, svm_node * const * x, const svm_parameter& param);//构造函数  
  19.     virtual ~Kernel();  
  20.   
  21.     static double k_function(const svm_node *x, const svm_node *y,  
  22.                  const svm_parameter& param);  
  23.     virtual Qfloat *get_Q(int column, int len) const = 0;  
  24.     virtual double *get_QD() const = 0;  
  25.     virtual void swap_index(int i, int j) const // no so const...  
  26.     {  
  27.         swap(x[i],x[j]);  
  28.         if(x_square) swap(x_square[i],x_square[j]);  
  29.     }  
  30. protected:  
  31.   
  32.     double (Kernel::*kernel_function)(int i, int j) const;  
  33.   
  34. private:  
  35.     const svm_node **x;//用来指向样本数据,每次数据传入时通过克隆函数来实现,完全重新分配内存,主要是为处理多类着想  
  36.     double *x_square;//使用RBF 核才使用  
  37.   
  38.     // svm_parameter  
  39.     const int kernel_type;  
  40.     const int degree;  
  41.     const double gamma;  
  42.     const double coef0;  
  43.   
  44.     static double dot(const svm_node *px, const svm_node *py);  
  45.   
  46.     double kernel_linear(int i, int j) const  
  47.     {  
  48.         return dot(x[i],x[j]);  
  49.     }  
  50.     double kernel_poly(int i, int j) const  
  51.     {  
  52.         return powi(gamma*dot(x[i],x[j])+coef0,degree);  
  53.     }  
  54.     double kernel_rbf(int i, int j) const  
  55.     {  
  56.         return exp(-gamma*(x_square[i]+x_square[j]-2*dot(x[i],x[j])));  
  57.     }  
  58.   
  59.     double kernel_sigmoid(int i, int j) const  
  60.     {  
  61.         return tanh(gamma*dot(x[i],x[j])+coef0);  
  62.     }  
  63.     double kernel_precomputed(int i, int j) const  
  64.     {  
  65.         return x[i][(int)(x[j][0].value)].value;  
  66.     }  
  67. };  
  68.   
  69. //构造函数,初始化类中的部分常量,指定核函数,克隆样本数据。如果使用RBF核函数,则计算x_square[i]  
  70. Kernel::Kernel(int l, svm_node * const * x_, const svm_parameter& param)  
  71. :kernel_type(param.kernel_type), degree(param.degree),  
  72.  gamma(param.gamma), coef0(param.coef0)  
  73. {  
  74.     switch(kernel_type)  
  75.     {  
  76.         case LINEAR:  
  77.             kernel_function = &Kernel::kernel_linear;  
  78.             break;  
  79.         case POLY:  
  80.             kernel_function = &Kernel::kernel_poly;  
  81.             break;  
  82.         case RBF:  
  83.             kernel_function = &Kernel::kernel_rbf;  
  84.             break;  
  85.         case SIGMOID:  
  86.             kernel_function = &Kernel::kernel_sigmoid;  
  87.             break;  
  88.         case PRECOMPUTED:  
  89.             kernel_function = &Kernel::kernel_precomputed;  
  90.             break;  
  91.     }  
  92.   
  93.     clone(x,x_,l);//void clone(T*& dst, S* src, int n)  
  94.   
  95.     if(kernel_type == RBF)  
  96.     {  
  97.         x_square = new double[l];  
  98.         for(int i=0;i<l;i++)  
  99.             x_square[i] = dot(x[i],x[i]);  
  100.     }  
  101.     else  
  102.         x_square = 0;  
  103. }  
  104.   
  105. Kernel::~Kernel()  
  106. {  
  107.     delete[] x;  
  108.     delete[] x_square;  
  109. }  
  110.   
  111. double Kernel::dot(const svm_node *px, const svm_node *py)  
  112. {  
  113.     double sum = 0;  
  114.     while(px->index != -1 && py->index != -1)  
  115.     {  
  116.         if(px->index == py->index)  
  117.         {  
  118.             sum += px->value * py->value;  
  119.             ++px;  
  120.             ++py;  
  121.         }  
  122.         else  
  123.         {  
  124.             if(px->index > py->index)  
  125.                 ++py;  
  126.             else  
  127.                 ++px;  
  128.         }             
  129.     }  
  130.     return sum;  
  131. }  
  132.   
  133. double Kernel::k_function(const svm_node *x, const svm_node *y,  
  134.               const svm_parameter& param)  
  135. {  
  136.     switch(param.kernel_type)  
  137.     {  
  138.         case LINEAR:  
  139.             return dot(x,y);  
  140.         case POLY:  
  141.             return powi(param.gamma*dot(x,y)+param.coef0,param.degree);  
  142.         case RBF:  
  143.         {  
  144.             double sum = 0;  
  145.             while(x->index != -1 && y->index !=-1)  
  146.             {  
  147.                 if(x->index == y->index)  
  148.                 {  
  149.                     double d = x->value - y->value;  
  150.                     sum += d*d;  
  151.                     ++x;  
  152.                     ++y;  
  153.                 }  
  154.                 else  
  155.                 {  
  156.                     if(x->index > y->index)  
  157.                     {     
  158.                         sum += y->value * y->value;  
  159.                         ++y;  
  160.                     }  
  161.                     else  
  162.                     {  
  163.                         sum += x->value * x->value;  
  164.                         ++x;  
  165.                     }  
  166.                 }  
  167.             }  
  168.   
  169.             while(x->index != -1)  
  170.             {  
  171.                 sum += x->value * x->value;  
  172.                 ++x;  
  173.             }  
  174.   
  175.             while(y->index != -1)  
  176.             {  
  177.                 sum += y->value * y->value;  
  178.                 ++y;  
  179.             }  
  180.               
  181.             return exp(-param.gamma*sum);  
  182.         }  
  183.         case SIGMOID:  
  184.             return tanh(param.gamma*dot(x,y)+param.coef0);  
  185.         case PRECOMPUTED:  //x: test (validation), y: SV  
  186.             return x[(int)(y->value)].value;  
  187.         default:  
  188.             return 0;  // Unreachable   
  189.     }  
  190. }  

转载于:https://my.oschina.net/u/1269935/blog/365385

你可能感兴趣的文章
如何用UPA优化性能?先读懂这份报告!
查看>>
这些Java面试题必须会-----鲁迅
查看>>
Linux 常用命令
查看>>
NodeJS 工程师必备的 8 个工具
查看>>
CSS盒模型
查看>>
ng2路由延时加载模块
查看>>
使用GitHub的十个最佳实践
查看>>
脱离“体验”和“安全”谈盈利的游戏运营 都是耍流氓
查看>>
慎用!BLEU评价NLP文本输出质量存在严重问题
查看>>
基于干净语言和好奇心的敏捷指导
查看>>
Node.js 2017企业用户调查结果发布
查看>>
“软”苹果水逆的一周:杂志服务崩溃,新机型遭泄露,芯片首架离职
查看>>
JAVA的优势就是劣势啊!
查看>>
ELK实战之logstash部署及基本语法
查看>>
帧中继环境下ospf的使用(点到点模式)
查看>>
BeanShell变量和方法的作用域
查看>>
LINUX下防恶意扫描软件PortSentry
查看>>
由数据库对sql的执行说JDBC的Statement和PreparedStatement
查看>>
springmvc+swagger2
查看>>
软件评测-信息安全-应用安全-资源控制-用户登录限制(上)
查看>>