文档章节

libSVM源码分析

y
 yunpiao
发布于 2014/12/26 09:59
字数 2676
阅读 66
收藏 0

转载请注明原载地址:http://blog.csdn.net/xinhanggebuguake/article/details/8705648 

 

在此之前,上海交大模式分析与机器智能实验室对2.6版本的svm.cpp做了部分注解,《LibSVM学习(四)——逐步深入LibSVM》也介绍了libSVM的思路,很精彩。而我写这篇博客更侧重与理解算法流程与具体代码的结合点。(环境:LibSVM2.6  C-SVCSVM   RBF核函数)

函数调用流程:

svm-train.c

main()

parse_command_line();//解析命令行,将数据读入param,并获取input filemodel file

   read_problem();//读取input file中的数据到prob中。

   (do_cross_validation();//该函数将试验所有的核函数,根据交叉验证选择最好的核函数)

   svm_train(&prob,&param);

      ->统计classes的数量以及每个classes下样本数量

      ->把相同类别的训练数据分组,每个分组开始的索引记录在start数组里。

      ->计算每个类别的惩罚因子C

      ->训练k*(k-1)/2个分类器模型

      ->svm_train_one();

         ->solve_c_svc();

            ->s.Solve();

                 ->初始化alpha_statusactive_setactive_size                  

                  ->求梯度

迭代优化:          ->do_shrinking(); //把数据分成active_sizeactive_size-L的部分集中排序。

                ->select_working_set(); //选择两个样本

                 ->更新alpha[i]alpha[j]的值

                 ->更新GG_Bar

                 ->calculate_rho();//计算b

->计算目标值

   svm_save_model(model_file_name,model);

   svm_destroy_model(model);

   svm_destroy_param(&param);

1、read_problem()

prob.y //记录每行样本所属类别

prob.x //指针数组(L),每个指针指向x_sapce(实际存储特征词)的一维

x_space//实际的存储结构,记录所有样本的特征词(L*(k+1)个),可以形象化为L维,虽然每一维的长度可能不同。

prob.yprob.xx_space的关系如下图所示:

2、统计classes数据

使用以下变量遍历所有样本,统计数据。

label[i] //记录类别

count[i] //记录类别中样本的数量

index[i] //记录位置为i的样本的类别

nr_class //索引类别的数目

3、训练数据分组

训练数据进行分组时使用到了以下数据:

int *start = Malloc(int,nr_class);

svm_node **x = Malloc(svm_node *,l);

两者之间通过index进行过渡,因为index记录了位置i的样本的类别,每个类别在start中只有一个位置,即该类别在x中的起始的索引。x是各类的排列顺序是按照原始样本中各类出现的先后顺序排列的,prob中则是原始样本的label序号排列,而start中记录的是各类的起始序号,而这个序号是在x的序号。

4、训练k*(k-1)/2个分类器模型

svm对于多类别的分类方法有多种,但将实现分为两个过程:训练阶段,判别阶段。

11-V-R方式

   对于k类问题,把其中某一类的n个训练样本视为一类,所有其他类别归为另一类,因此共有k个分类器。最后,判别使用竞争方式,也就是哪个类得票多就属于那个类。

21-V-1方式

   one-against-one方式。该方法把其中的任意两类构造一个分类器,共有(k-1)×k/2个分类器。最后判别也采用竞争方式。

31-V-1libSVM中的实现

    LibSVM采用的是1-V-1方式,因为这种方式思路简单,并且许多实践证实效果比1-V-R方式要好。该方法在训练阶段采用1-V-1方式,而判别阶段采用一种两向有向无环图的方式。

训练阶段:

 

上图是一个51-V-1组合的示意图,红色是0类和其他类的组合,紫色是1类和剩余类的组合,绿色是2类与右端两类的组合,蓝色只有34的组合。因此,对于nr_class个类的组合方式为:

[cpp]  view plain copy
<EMBED id=ZeroClipboardMovie_1 height=18 name=ZeroClipboardMovie_1 type=application/x-shockwave-flash align=middle pluginspage=http://www.macromedia.com/go/getflashplayer width=18 src=http://static.blog.csdn.net/scripts/ZeroClipboard/ZeroClipboard.swf wmode="transparent" flashvars="id=1&width=18&height=18" allowfullscreen="false" allowscriptaccess="always" bgcolor="#ffffff" quality="best" menu="false" loop="false">
  1. for(i = 0; i < nr_class; i ++)  
  2. {  
  3.     for(j = i+1; i < nr_class; j ++)        
  4.     {   
  5.        类 i –V – 类 j  
  6.     }   
  7. }  

判别阶段:

在对一篇文章进行分类之前,我们先按照下面图的样子来组织分类器(如你所见,这是一个有向无环图,因此这种方法也叫做DAG-SVM

在分类时,我们就可以先问分类器“15”(意思是它能够回答是第1类还是第5),如果它回答5,我们就往左走,再问“25”这个分类器,如果它还说是“5”,我们就继续往左走,这样一直问下去,就可以得到分类结果。

5、计算梯度

主要代码如下:

[cpp]  view plain copy
<EMBED id=ZeroClipboardMovie_2 height=18 name=ZeroClipboardMovie_2 type=application/x-shockwave-flash align=middle pluginspage=http://www.macromedia.com/go/getflashplayer width=18 src=http://static.blog.csdn.net/scripts/ZeroClipboard/ZeroClipboard.swf wmode="transparent" flashvars="id=2&width=18&height=18" allowfullscreen="false" allowscriptaccess="always" bgcolor="#ffffff" quality="best" menu="false" loop="false">
  1. G[i] = b[i];   
  2. G_bar[i] = 0;  
  3. Qfloat *Q_i = Q.get_Q(i,l);   
  4. for(i=0;i<l;i++)  
  5. {  
  6.     for(j=0;j<l;j++)  
  7.        G[j] += alpha_i*Q_i[j];  
  8.     for(j=0;j<l;j++)  
  9.        G_bar[j] += get_C(i) * Q_i[j];  
  10. }  

首先,Q.get_Q(i,l)返回data,而

data[j] = (Qfloat)(y[i]*y[j]*(this->*kernel_function)(i,j));

翻译成公式,即:

                              

所以,以上计算梯度的代码翻译成公式,则:

G为:

             (5.1)

G_bar为:

                     (5.2)   

6、数据选择select_working_set(i,j) 

理论依据:     

   对于样本数量比较多的时候(几千个),SVM所需要的内存是计算机所不能承受的。目前,对于这个问题的解决方法主要有两种:块算法和分解算法。这里,libSVM采用的是分解算法中的SMO(串行最小化)方法,其每次训练都只选择两个样本。我们不对SMO做具体的讨论,要想深入了解可以查阅相关的资料,这里只谈谈和程序有关的知识。

   一般SVM的对偶问题为:

       

 S.t.                                     6.1

                                                                                              

SVM收敛的充分必要条件是KKT条件,其表现为:

              6.2

6.1式求导可得:

                                                   (6.3 

进一步推导可知:

                                                            6.4

也就是说,只要所有的样本都满足6.4式,那么得到解就是最优值。因此,在每轮训练中,每次只要选择两个样本(序号为ij),是最违反KKT条件(也就是6.4式)的样本,就能保证其他样本也满足KKT条件。序号ij的选择方式如下: 

                                            6.5

libSVM实现:

由公式5.16.5可知,select_working_set的过程,只跟G_barC有关,所以根据is_lower_boundis_upper_bound判断C的范围,再根据y[i],可以将公式6.5分为8个分支。循环遍历所有样本,就能查找到最违反KTT条件的样本的index

7、数据缩放do_shrinking()

   上面说到SVM用到的内存巨大,另一个缺陷就是计算速度,因为数据大了,计算量也就大,很显然计算速度就会下降。因此,一个好的方式就是在计算过程中逐步去掉不参与计算的数据。因为,实践证明,在训练过程中,alpha[i]一旦达到边界(alpha[i]=0或者alpha[i]=C),alpha[i]值就不会变,随着训练的进行,参与运算的样本会越来越少,SVM最终结果的支持向量(0<alpha[i]<C)往往占很少部分。

   LibSVM采用的策略是在计算过程中,检测active_size中的alpha[i]值,如果alpha[i]到了边界,那么就应该把相应的样本去掉(变成inactived),并放到栈的尾部,从而逐步缩小active_size的大小。

8、迭代优化停止准则

    LibSVM程序中,停止准则蕴含在了函数select_working_set(i,j)返回值中。也就是,当找不到符合6.5式的样本时,那么理论上就达到了最优解。但是,实际编程时,由于KKT条件还是蛮苛刻的,要进行适当的放松。令: 

                                    8.1

6.4式可知,当所有样本都满足KKT条件时,gi ≤ -gj

加一个适当的宽松范围ε,也就是程序中的eps,默认为0.001,那么最终的停止准则为:

                     gi ≤ -gj +ε  →    gi + gj ≤ε

9、因子α的更新

理论依据:

由于SMO每次都只选择2个样本,那么4.1式的等式约束可以转化为直线约束: 

             9.1

转化为图形表示为: 

 把式9.1α1α表示,即:,结合上图由解析几何可得α2的取值范围: 

9.2

经过一系列变换,可以得到的α2更新值α2new

                                                                              9.3

结合9.29.3式得到α2new最终表达式:

                                                                                              9.4

得到α2new后,就可以由9.1式求α1new

libSVM实现:

具体操作的时候,把选择后的序号ij代替这里的21就可以了。当然,编程时,这些公式还是太抽象。对于9.2式,还需要具体细分。比如,对于y1y2=-1时的L = max(0,α2- α1),是0大还α2- α1是大的问题。总共需要分8种情况。至于程序中在一个分支中给α1newα2new同时赋值,是因为两者之间存在的关系:

diff = alpha[i] - alpha[j];

依据公式9.4,最内层对alpha[i](alpha[j])判断可以得出alpha[i] (alpha[j])的值,代入以上公式可得另外一个的值。

10、更新GG_Bar

根据的变化更新G(i),更新alpha_status较简单,根据alpha状态前后是否有变化,适当更新,更新的内容参考公式5.2

11、截距b的计算

b计算的基本公式为:

            11.1

理论上,b的值是不定的。当程序达到最优后,只要用任意一个标准支持向量机(0<alpha[i]<C)的样本带入11.1式,得到的b值都是可以的。目前,求b的方法也有很多种。在libSVM中,分别对y=+1y=-1的两类所有支持向量求b,然后取平均值:

                                 

12、计算目标函数值

因为目标值的计算公式为:1/2*alpha*Sigma (G[i]+b[i])

G[i]转换为公式为alpha_i*Q_i[j]+b[i]

由于在传递给Solve函数的minus_ones将所有值赋为-1,所以b[i]=-1,以上公式就转换为

1/2*alpha*Sigma (Q_i[j]) + 1/2*alpha*2*(-1);

上面的公式不正是我们的目标函数吗。所以可以理解libSVM中的实现。


 

总结:由于刚接触相关方面的知识,疏漏之处在所难免,希望各位高手能不吝赐教!


参考资料:

http://blog.csdn.net/xinhanggebuguake/article/details/8705631

http://www.blogjava.net/zhenandaci/archive/2009/03/26/262113.html

http://www.cnblogs.com/biyeymyhjob/archive/2012/07/17/2591592.html

http://blog.csdn.net/flydreamGG/article/details/4470121

libsvm-2[2].8程序代码导读       刘国平

序列最小化方法                 罗林开

本文转载自:http://blog.csdn.net/xinhanggebuguake/article/details/8705648

共有 人打赏支持
y
粉丝 2
博文 41
码字总数 71902
作品 0
海淀
私信 提问
MATLAB安装libsvm工具箱的方法

支持向量机(support vector machine,SVM)是机器学习中一种流行的学习算法,在分类与回归分析中发挥着重要作用。基于SVM算法开发的工具箱有很多种,下面我们要安装的是十分受欢迎的libsvm工...

东聃
2018/08/12
0
0
LibSvm使用说明和LibSvm源码解析

kernel_type rbf //训练采用的核函数类型,此处为RBF核gamma 0.0769231 //RBF核的参数γnr_class 2 //类别数,此处为两分类问题total_sv 132 //支持向量总个数rho 0.424462 //判决函数的偏置...

haoji007
2018/05/13
0
0
sparkmlib的sample_binary_classification_data.txt 和sample_libsvm_data.txt内容怎么换成实际项目内容

sample_binary_classification_data.txt 和sample_libsvm_data.txt的内容怎么理解和使用 sparkmlib的sample_binary_classification_data.txt 和sample_libsvm_data.txt内容怎么换成实际项目内......

知行合一1
2017/09/30
20
0
【毕设进行时-工业大数据,数据挖掘】LIBSVM 初步测试

正文之前 打摆子的日子很快就要一去不复返了。想想有点悲伤。今天做了下LibSVM的初步运用,也写了个从数据库读取数据,然后改造成LibSVM需要的数据格式的类,需要的自取。 正文 这是个从前面...

HustWolf
2018/04/23
0
0
python下使用libsvm:计算点到超平面的距离

最近在看的资料里涉及到计算 点到支持向量机分类超平面的距离 这一点内容,我使用的svm是libsvm。 由于是新手,虽然看了一些资料,但中英转换误差等等原因导致经常出现理解错误,因此对libsv...

小梳子一直走
2014/03/17
0
0

没有更多内容

加载失败,请刷新页面

加载更多

容器服务

简介 容器服务提供高性能可伸缩的容器应用管理服务,支持用 Docker 和 Kubernetes 进行容器化应用的生命周期管理,提供多种应用发布方式和持续交付能力并支持微服务架构。 产品架构 容器服务...

狼王黄师傅
昨天
3
0
高性能应用缓存设计方案

为什么 不管是刻意或者偶尔看其他大神或者大师在讨论高性能架构时,自己都是认真的去看缓存是怎么用呢?认认真真的看完发现缓存这一块他们说的都是一个WebApp或者服务的缓存结构或者缓存实现...

呼呼南风
昨天
12
0
寻找一种易于理解的一致性算法(扩展版)

摘要 Raft 是一种为了管理复制日志的一致性算法。它提供了和 Paxos 算法相同的功能和性能,但是它的算法结构和 Paxos 不同,使得 Raft 算法更加容易理解并且更容易构建实际的系统。为了提升可...

Tiny熊
昨天
3
0
聊聊GarbageCollectionNotificationInfo

序 本文主要研究一下GarbageCollectionNotificationInfo CompositeData java.management/javax/management/openmbean/CompositeData.java public interface CompositeData { public Co......

go4it
昨天
3
0
阿里云ECS的1M带宽理解

本文就给大家科普下阿里云ECS的固定1M带宽的含义。 “下行带宽”和“上行带宽” 为了更好的理解,需要先给大家解释个词“下行带宽”和“上行带宽”: 下行带宽:粗略的解释就是下载数据的最大...

echojson
昨天
10
0

没有更多内容

加载失败,请刷新页面

加载更多

返回顶部
顶部