Weka开发[11]—J48源代码介绍
博客专区 > pior 的博客 > 博客详情
Weka开发[11]—J48源代码介绍
pior 发表于2年前
Weka开发[11]—J48源代码介绍
  • 发表于 2年前
  • 阅读 189
  • 收藏 0
  • 点赞 0
  • 评论 0

华为云·免费上云实践>>>   

这次介绍一下J48的源码,分析J48的源码似乎真还是有用的,同学改造J48写过VFDT,我自己用J48进行特征选择(当然很失败)。

J48buildClassfier函数:

public void buildClassifier(Instances instances) throws Exception {

    ModelSelection modSelection;

 

    if (m_binarySplits)

       modSelection = new BinC45ModelSelection(m_minNumObj, instances);

    else

       modSelection = new C45ModelSelection(m_minNumObj, instances);

    if (!m_reducedErrorPruning)

       m_root = new C45PruneableClassifierTree(modSelection,

!m_unpruned, m_CF, m_subtreeRaising, !m_noCleanup);

    else

       m_root = new PruneableClassifierTree(modSelection, !m_unpruned,

                  m_numFolds, !m_noCleanup, m_Seed);

    m_root.buildClassifier(instances);

    if (m_binarySplits) {

       ((BinC45ModelSelection) modSelection).cleanup();

    } else {

       ((C45ModelSelection) modSelection).cleanup();

    }

}

    NBTree中已经介绍过了,ModelSelection是决定决策树的模型类,前面两个if,一个是判断连续属性是否只分出两个子结点,另一个判断是否最后剪枝。m_root是一个ClassifierTree对象,它调用buildClassifier函数。这里列出这个函数:

public void buildClassifier(Instances data) throws Exception {

    // can classifier tree handle the data?

    getCapabilities().testWithFail(data);

 

    // remove instances with missing class

    data = new Instances(data);

    data.deleteWithMissingClass();

 

    buildTree(data, false);

}

    有注释也没什么好说的,直接看最后一个函数buildTree

public void buildTree(Instances data, boolean keepData) throws Exception {

    Instances[] localInstances;

 

    if (keepData) {

       m_train = data;

    }

    m_test = null;

    m_isLeaf = false;

    m_isEmpty = false;

    m_sons = null;

    m_localModel = m_toSelectModel.selectModel(data);

    if (m_localModel.numSubsets() > 1) {

       localInstances = m_localModel.split(data);

       data = null;

       m_sons = new ClassifierTree[m_localModel.numSubsets()];

       for (int i = 0; i < m_sons.length; i++) {

           m_sons[i] = getNewTree(localInstances[i]);

           localInstances[i] = null;

       }

    } else {

       m_isLeaf = true;

       if (Utils.eq(data.sumOfWeights(), 0))

           m_isEmpty = true;

       data = null;

    }

}

这里的selectModel函数,如果看过NBTree一篇的读者应该不会太陌生,selectModel简单地说就是如果不符合分裂的条件就返回NoSplit,如果符合分裂的条件,则从currentModel数组中选出bestModel返回。

这最要注意的是selectModel也不只是决定哪个属性分裂,其实到底如何分裂已经在这个函数里算里出来了。

    我把selectModel拆开来讲解

// Check if all Instances belong to one class or if not

// enough Instances to split.

checkDistribution = new Distribution(data);

noSplitModel = new NoSplit(checkDistribution);

if (Utils.sm(checkDistribution.total(), 2 * m_minNoObj)

       || Utils.eq(checkDistribution.total(), checkDistribution

              .perClass(checkDistribution.maxClass())))

              return noSplitModel;

    2 * m_minNoObj表示至有有这么多样本才可以分裂,原因很简单,因为一个结点至少分出两个子结点,每个子结点至少有m_minNoObj个样本,第二个是或条件是表示是否这个结点上所有的样本都属于同一类别,也就是这个结点总的权重是否等于这个最多类别的权重。

// Check if all attributes are nominal and have a lot of values.

if (m_allData != null) {

    Enumeration enu = data.enumerateAttributes();

    while (enu.hasMoreElements()) {

       attribute = (Attribute) enu.nextElement();

       if ((attribute.isNumeric())

           || (Utils.sm((double) attribute.numValues(),

              (0.3 * (double) m_allData.numInstances())))) {

                     multiVal = false;

                     break;

                  }

       }

}

    判断是否有很多不同的属性值,标准就是如果有一个属性的属性值小多于总样本数*0.3,那么就是不是multiVal

currentModel = new C45Split[data.numAttributes()];

sumOfWeights = data.sumOfWeights();

 

// For each attribute.

for (i = 0; i < data.numAttributes(); i++) {

 

    // Apart from class attribute.

    if (i != (data).classIndex()) {

 

    // Get models for current attribute.

    currentModel[i] = new C45Split(i, m_minNoObj, sumOfWeights);

    currentModel[i].buildClassifier(data);

 

    // Check if useful split for current attribute

    // exists and check for enumerated attributes with

    // a lot of values.

    if (currentModel[i].checkModel())

       if (m_allData != null) {

           if ((data.attribute(i).isNumeric())

              || (multiVal || Utils.sm((double) data

                  .attribute(i).numValues(),

                   (0.3 * (double) m_allData.numInstances())))) {

                         averageInfoGain = averageInfoGain

                                + currentModel[i].infoGain();

                         validModels++;

                     }

              } else {

                     averageInfoGain = averageInfoGain

                         + currentModel[i].infoGain();

                     validModels++;

               }

        } else

           currentModel[i] = null;

}

    里面重要的两句就是:

// Get models for current attribute.

currentModel[i] = new C45Split(i, m_minNoObj, sumOfWeights);

currentModel[i].buildClassifier(data);

    其它的也没有什么,求一下averageInfoGainvalidModelscheckModel如果可以分出子结点则为真。

这里是C45Split类的成员函数buildClassfier被调用,列出它的代码:

public void buildClassifier(Instances trainInstances) throws Exception {

    // Initialize the remaining instance variables.

    m_numSubsets = 0;

    m_splitPoint = Double.MAX_VALUE;

    m_infoGain = 0;

    m_gainRatio = 0;

 

    // Different treatment for enumerated and numeric

    // attributes.

    if (trainInstances.attribute(m_attIndex).isNominal()) {

        m_complexityIndex = trainInstances.attribute(m_attIndex)

.numValues();

      m_index = m_complexityIndex;

      handleEnumeratedAttribute(trainInstances);

    }else{

      m_complexityIndex = 2;

      m_index = 0;

      trainInstances.sort(trainInstances.attribute(m_attIndex));

      handleNumericAttribute(trainInstances);

    }

  }  

这里handleEnumerateAttributehandleNumericAttribute是决定到底是哪一个属性分裂(m_attIndex)和分裂出几个子结点的函数(m_numSubsets)。这里的m_comlexity就是指分可以分裂出多少子结点。如果是连续属性就是2。再看一下handleEnumeratedAttribute函数:

private void handleEnumeratedAttribute(Instances trainInstances)

throws Exception {

 

    Instance instance;

 

    m_distribution = new Distribution(m_complexityIndex,

       trainInstances.numClasses());

 

    // Only Instances with known values are relevant.

    Enumeration enu = trainInstances.enumerateInstances();

    while (enu.hasMoreElements()) {

        instance = (Instance) enu.nextElement();

        if (!instance.isMissing(m_attIndex))

           m_distribution.add((int) instance.value(m_attIndex),

              instance);

    }

 

    // Check if minimum number of Instances in at least two

    // subsets.

    if (m_distribution.check(m_minNoObj)) {

       m_numSubsets = m_complexityIndex;

       m_infoGain = infoGainCrit.splitCritValue(m_distribution,

              m_sumOfWeights);

       m_gainRatio = gainRatioCrit.splitCritValue(m_distribution,

              m_sumOfWeights, m_infoGain);

    }

}

// Current attribute is a numeric attribute.

m_distribution = new Distribution(2, trainInstances.numClasses());

 

// Only Instances with known values are relevant.

Enumeration enu = trainInstances.enumerateInstances();

i = 0;

while (enu.hasMoreElements()) {

    instance = (Instance) enu.nextElement();

    if (instance.isMissing(m_attIndex))

       break;

    m_distribution.add(1, instance);

    i++;

}

 

firstMiss = i;

       已经讲过了,如果是连续属性就分出两个子结点,也就是Distribution的第一个参数。枚举所有样本,因为在调用HandleNumericAttribute之间已经对数据集根据m_attIndex排序过,所以缺失数据都在最后。也就是firstMiss是在m_attIndex上有确定值的样本个数+1。在while循环中,把所有的样本都先放到bag 1(add(1,instance))。还是列出来一下吧。

public final void add(int bagIndex, Instance instance) throws Exception {

    int classIndex;

    double weight;

 

    classIndex = (int) instance.classValue();

    weight = instance.weight();

    m_perClassPerBag[bagIndex][classIndex] =

       m_perClassPerBag[bagIndex][classIndex] + weight;

    m_perBag[bagIndex] = m_perBag[bagIndex] + weight;

    m_perClass[classIndex] = m_perClass[classIndex] + weight;

    totaL = totaL + weight;

}

       也就这个函数也就是根据参数bagIndex和样本的类别值classIndex,三个成员变量m_perBag, m_perClass, m_perClassPerBag分别加上样本的权重。

// Compute minimum number of Instances required in each subset.

minSplit = 0.1 * (m_distribution.total())

       / ((double) trainInstances.numClasses());

if (Utils.smOrEq(minSplit, m_minNoObj))

    minSplit = m_minNoObj;

else if (Utils.gr(minSplit, 25))

    minSplit = 25;

 

// Enough Instances with known values?

if (Utils.sm((double) firstMiss, 2 * minSplit))

    return;

       计算分最小分裂需要的样本数,这些涉及的值在Quinlan的论文中没有提到,可能也没有太多的道理,就是如果样本数的1/10小于m_minNoObj那么最小分裂样本数就是m_minNoObj,如果大于25,最小分裂样本数就是25

       如果firstMiss小于2*minSplit表示已经不可以再分裂了(为什么刚才已经讲过了)。

// Compute values of criteria for all possible split indices.

defaultEnt = infoGainCrit.oldEnt(m_distribution);

while (next < firstMiss) {

 

    if (trainInstances.instance(next - 1).value(m_attIndex)

       + 1e-5 < trainInstances.instance(next).value(m_attIndex)) {

 

       // Move class values for all Instances up to next

       // possible split point.

       m_distribution.shiftRange(1, 0, trainInstances, last, next);

 

        // Check if enough Instances in each subset and compute

       // values for criteria.

       if (Utils.grOrEq(m_distribution.perBag(0), minSplit)

           && Utils.grOrEq(m_distribution.perBag(1), minSplit)) {

           currentInfoGain = infoGainCrit.splitCritValue(

              m_distribution, m_sumOfWeights, defaultEnt);

           if (Utils.gr(currentInfoGain, m_infoGain)) {

              m_infoGain = currentInfoGain;

              splitIndex = next - 1;

           }

           m_index++;

       }

       last = next;

    }

    next++;

}

       oldEnt计算没有分裂的信息增益,得到defaultEnt注意,刚才是把样本放在了一个bag中。然后对所有有确定值的样本进行循环。第一个if,如果两个属性值太接近,那么选择的分裂点不会有太大的区别,就不进行处理。shiftRange是把第一个bag中下标从lastnext-1的样本移到第0bagshiftRange代码如下:

public final void shiftRange(int from, int to, Instances source,

       int startIndex, int lastPlusOne) throws Exception {

    int classIndex;

    double weight;

    Instance instance;

    int i;

 

    for (i = startIndex; i < lastPlusOne; i++) {

       instance = (Instance) source.instance(i);

       classIndex = (int) instance.classValue();

       weight = instance.weight();

       m_perClassPerBag[from][classIndex] -= weight;

       m_perClassPerBag[to][classIndex] += weight;

       m_perBag[from] -= weight;

       m_perBag[to] += weight;

    }

}

       很简单就是把对应样本的样本权重从from bag中减去,再加到to bag中。

       转回来,如果bag 1bag 0都满足最小分裂样本数,计算在当前分裂点上的信息增益值。如果比上一个最好的分裂点的信息增益高,那么记录下当前的信息增益值为最高信息增益值m_infoGain,和当前分裂点splitIndex

// Was there any useful split?

if (m_index == 0)

    return;

 

// Compute modified information gain for best split.

m_infoGain = m_infoGain - (Utils.log2(m_index) / m_sumOfWeights);

if (Utils.smOrEq(m_infoGain, 0))

    return;

 

// Set instance variables' values to values for best split.

m_numSubsets = 2;

m_splitPoint = (trainInstances.instance(splitIndex + 1).value(

    m_attIndex) + trainInstances.instance(splitIndex).value(

    m_attIndex)) / 2;

       如果没有找到任何分裂点,返回,接下来的m_infoGain自己到J.R.QuinlanImproved use of continuous Attributes in C4.5论文中的第4页第二段中找。最后设置有两个结点,分裂点在刚才找到的最好的分裂点与下一个属性值的中点。

// In case we have a numerical precision problem we need to choose the

// smaller value

if (m_splitPoint == trainInstances.instance(splitIndex + 1).value(

    m_attIndex)) {

    m_splitPoint = trainInstances.instance(splitIndex).value(

       m_attIndex);

}

 

// Restore distributioN for best split.

m_distribution = new Distribution(2, trainInstances.numClasses());

m_distribution.addRange(0, trainInstances, 0, splitIndex + 1);

m_distribution.addRange(1, trainInstances, splitIndex + 1, firstMiss);

 

// Compute modified gain ratio for best split.

m_gainRatio = gainRatioCrit.splitCritValue(m_distribution,

       m_sumOfWeights, m_infoGain);

       if是处理精度的细节问题。然后重新通过addRange计算m_distribution,最后计算增益率(Gain Ratio)

 

这里看到又有一个新类Distribution类,还是要把Distribution类讲一下,Distribution类中有一个bag成员变量,它的意思是能有几个子结点。从下面的构造函数看出来的,第一个参数在上面调用它的时候用的就是m_complexityIndex.

public Distribution(int numBags, int numClasses) {

    int i;

 

    m_perClassPerBag = new double[numBags][0];

    m_perBag = new double[numBags];

    m_perClass = new double[numClasses];

    for (i = 0; i < numBags; i++)

       m_perClassPerBag[i] = new double[numClasses];

    totaL = 0;

}

Distributionadd函数就是在相应的属性值上进行统计,太简单了,略过。

回到刚才的buildTree函数,如果numSubsets返回1,则表示当前结点不再分裂为叶子结点,如果大于1,那么调用split函数,split函数只是根据有上次得到的子结点数,并根据WhichSubset返回值,把当前结点的样本分到几个子结点去。再对每一个子结点训练一个新子树,到这已经与以前讲的ID3有很大的相似了。

可能大家学习的时候都对理论很感兴趣,但看了半天也没看到,有点不解,其实也很好找,当然应该在handleEnumerateAttributehandleNumericAttribute中了,也就是InfoGainSplitCritGainRatioSplitCrit两个类。

分裂一个样本与NBTree相似,这里不再赘述。


共有 人打赏支持
粉丝 25
博文 151
码字总数 22496
×
pior
如果觉得我的文章对您有用,请随意打赏。您的支持将鼓励我继续创作!
* 金额(元)
¥1 ¥5 ¥10 ¥20 其他金额
打赏人
留言
* 支付类型
微信扫码支付
打赏金额:
已支付成功
打赏金额: