详解聚类算法Kmeans的两大优化——mini-batch和Kmeans++

2019/04/10 10:10
阅读数 88

本文始发于个人公众号:TechFlow,原创不易,求个关注

<br>

<section id="nice" data-tool="mdnice编辑器" data-website="https://www.mdnice.com" style="font-size: 16px; color: black; padding: 0 10px; line-height: 1.6; word-spacing: 0px; letter-spacing: 0px; word-break: break-word; word-wrap: break-word; text-align: left; font-family: Optima-Regular, Optima, PingFangSC-light, PingFangTC-light, 'PingFang SC', Cambria, Cochin, Georgia, Times, 'Times New Roman', serif; margin-top: -10px;"><p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">今天是<strong style="font-weight: bold; color: rgb(71, 193, 168);">机器学习专题的第13篇</strong>文章,我们来看下Kmeans算法的优化。</p> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">在上一篇文章当中我们一起学习了Kmeans这个聚类算法,在算法的最后我们提出了一个问题:Kmeans算法虽然效果不错,但是每一次迭代都需要遍历全量的数据,一旦数据量过大,由于计算复杂度过大迭代的次数过多,会导致<strong style="font-weight: bold; color: rgb(71, 193, 168);">收敛速度非常慢</strong>。</p> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">想想看,如果我们是在面试当中遇到的这个问题,我们事先并不知道正解,我们应该怎么回答呢?</p> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">还是老套路,我们在回答问题之前,先来分析问题。问题是收敛速度慢,计算复杂度高。计算复杂度高的原因我们也知道了,<strong style="font-weight: bold; color: rgb(71, 193, 168);">一个是因为样本过大,另一个是因为迭代次数过多</strong>。所以显然,我们想要改进这个问题,应该从这两点入手。</p> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">这两点是问题的关键点,针对这两点我们其实可以想出很多种优化和改进的方法。也就是说这是一个开放性问题,相比标准答案,推导和思考问题的思路更加重要。相反,如果我们抓不住关键点,那么回答也会跑偏,这就是为什么我在面试的时候,有些候选人会回答使用分布式系统或者是增加资源加速计算,或者是换一种其他的算法的原因。</p> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">也就是说分析问题和解决问题的思路过程,比解决方法本身更加重要。</p> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">下面,我们就上面提到的两个关键点各介绍一个优化方法。</p> <h2 data-tool="mdnice编辑器" style="font-weight: bold; font-size: 22px; border-bottom: 2px solid rgb(89,89,89); margin-bottom: 50px; margin-top: 100px; color: rgb(89,89,89);"><span class="prefix" style="font-size: 22px; border-bottom: 2px solid rgb(89,89,89); display: none;"></span><span class="content" style="font-size: 22px; display: inline-block; border-bottom: 2px solid rgb(89,89,89);">mini batch</span><span class="suffix" style="font-size: 22px; display: inline-block; border-bottom: 2px solid rgb(89,89,89);"></span></h2> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">mini batch的思想非常朴素,既然全体样本当中数据量太大,会使得我们迭代的时间过长,那么我们<strong style="font-weight: bold; color: rgb(71, 193, 168);">缩小数据规模</strong>行不行?</p> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">那怎么减小规模呢,很简单,我们随机从整体当中做一个抽样,<strong style="font-weight: bold; color: rgb(71, 193, 168);">选取出一小部分数据来代替整体</strong>。这样我们人为地缩小样本的规模,不就可以提升迭代的速度了?</p> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">通过抽样我们的确可以提升迭代的效率,但是这样能保证正确性吗?</p> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">这个问题很好回答,我们只需要简单做个实验就可以证明。</p> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">我们利用上周开发的并没有经过任何优化的代码,并且将生成的样本的数量增加到五万,从下面的这张图我们可以看出,朴素的Kmeans足足用了37.2秒才完成了计算。我们得到的聚类结果如下:</p> <figure data-tool="mdnice编辑器" style="margin: 0; margin-top: 10px; margin-bottom: 10px;"><img src="https://user-gold-cdn.xitu.io/2020/3/25/1710f1d21484a8bb?w=516&h=303&f=jpeg&s=20022" alt style="display: block; margin: 0 auto; width: auto; max-width: 100%;"></figure> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">接着我们通过numpy下的random.choice,<strong style="font-weight: bold; color: rgb(71, 193, 168);">从中随机选择1000条样本</strong>,我们对比一下前后的耗时和结果。</p> <figure data-tool="mdnice编辑器" style="margin: 0; margin-top: 10px; margin-bottom: 10px;"><img src="https://user-gold-cdn.xitu.io/2020/3/25/1710f1d207c7463c?w=483&h=302&f=jpeg&s=18405" alt style="display: block; margin: 0 auto; width: auto; max-width: 100%;"></figure> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">我们再来看下两次聚类的中心,从图片上来看两者<strong style="font-weight: bold; color: rgb(71, 193, 168);">误差极小</strong>,我们打印出坐标来观察,误差在0.05以内,可以说是非常接近了。</p> <figure data-tool="mdnice编辑器" style="margin: 0; margin-top: 10px; margin-bottom: 10px;"><img src="https://user-gold-cdn.xitu.io/2020/3/25/1710f1d21944c0f7?w=511&h=218&f=jpeg&s=20473" alt style="display: block; margin: 0 auto; width: auto; max-width: 100%;"></figure> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">虽然mini batch的原理说穿了一钱不值,但是它的的确确非常重要,不仅重要而且在机器学习领域广为使用。在大数据的场景下,几乎所有模型都需要做mini batch优化。</p> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">但是我们不禁有一个问题,这个方案全靠随机,看起来非常不靠谱,会不会出现我们选出来的结果偏差特别大的情况,比如刚好都在一个簇当中?从理论上来看,这当然是可能的,所以为了谨慎起见,我们<strong style="font-weight: bold; color: rgb(71, 193, 168);">可以重复多次采样</strong>,再对计算到的类簇坐标计算均值,直到簇中心趋于稳定为止。或者可以人工设置迭代次数,直到满足迭代次数要求时停止。</p> <h2 data-tool="mdnice编辑器" style="font-weight: bold; font-size: 22px; border-bottom: 2px solid rgb(89,89,89); margin-bottom: 50px; margin-top: 100px; color: rgb(89,89,89);"><span class="prefix" style="font-size: 22px; border-bottom: 2px solid rgb(89,89,89); display: none;"></span><span class="content" style="font-size: 22px; display: inline-block; border-bottom: 2px solid rgb(89,89,89);">Kmeans ++</span><span class="suffix" style="font-size: 22px; display: inline-block; border-bottom: 2px solid rgb(89,89,89);"></span></h2> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">如果说mini batch是一种通用的方法,并且看起来有些儿戏的话,那么下面要介绍的方法则要硬核许多。这个方法<strong style="font-weight: bold; color: rgb(71, 193, 168);">直接在Kmeans算法本身上做优化</strong>因此被称为Kmeans++。</p> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">前文当中我们已经说过了,想要优化Kmeans算法的效率问题,大概有两个入手点。一个是样本数量太大,另一个是迭代次数过多。刚才我们介绍的mini batch针对的是样本数量过多的情况,Kmeans++的方法则是针对迭代次数。我们通过某种方法<strong style="font-weight: bold; color: rgb(71, 193, 168);">降低收敛需要的迭代次数,从而达到快速收敛的目的</strong>。</p> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">这个思路很明确,但是操作却不简单,迭代次数和收敛效果是相关的。也就是说<strong style="font-weight: bold; color: rgb(71, 193, 168);">在达到收敛之前,迭代次数是不能减少的</strong>,否则就会导致不收敛。而且聚类问题和分类问题不同,我们在分类问题当中有一个明确的损失函数用来优化。在我们使用梯度下降法的时候,还可以将梯度前的学习率设置得稍稍大一些,从而加快收敛的速度。但是聚类问题不同,尤其是Kmeans算法,我们的依次迭代,坐标变换的值是通过求平均坐标也就是质心的坐标得到的。除非我们修改迭代的逻辑,否则没办法加快迭代。</p> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">我们从算法运作的思路出发的确会得到这个结论,这个结论也是没问题的,但是有问题的是收敛的速度除了取决于每次迭代的变化率之外,还有另外一个重要的指标。就是<strong style="font-weight: bold; color: rgb(71, 193, 168);">迭代起始的位置</strong>。</p> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">也就是说我们是从怎样的情况开始收敛的,显然如果我们的初始状态离最终的收敛状态越近,那么收敛需要的迭代次数就越少,所以我们这个优化算法的目标就是想办法找到一个足够接近收敛结果的起始状态。这个思路应该也不难想通,但是这当中藏着一个巨大的疑问,我们在训练的时候并不知道收敛的状态是什么,又怎么能判断起始状态距离收敛结果的远近呢?</p> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">显然直接走是走不通的,我们需要迂回一下。</p> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">我们来分析一下,其实可以得到很多结论。首先,如果我们<strong style="font-weight: bold; color: rgb(71, 193, 168);">随机选择K个样本点作为起始的簇中心效果比随机K个坐标点更好</strong>。原因也很简单,因为我们随机坐标对应的是在最大和最小值框成的矩形面积当中选择K个点,而我们从样本当中选K个点的范围则要小得多。我们可以单纯从面积的占比就可以看得出来。由于样本具有聚集性,我们在样本当中选择起始状态,选到接近类簇的可能性要比随机选大得多。</p> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">但是还有一个小问题,比如说在上面的例子当中类簇是3,我们随机选择3个样本作为起始状态。但是问题来了,如果我们刚好选的3个点在一个类簇当中怎么办,那样到收敛状态不也需要很久吗?</p> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">这个问题的确是存在的,我们要避免选到同一个簇中点的情况。但是由于我们并不知道样本的分布情况,怎么来判断呢?</p> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">这个时候需要用到聚类的另一个性质,我们再来观察一下上面的图:</p> <figure data-tool="mdnice编辑器" style="margin: 0; margin-top: 10px; margin-bottom: 10px;"><img src="https://user-gold-cdn.xitu.io/2020/3/25/1710f1d25808e9a4?w=516&h=303&f=jpeg&s=20022" alt style="display: block; margin: 0 auto; width: auto; max-width: 100%;"></figure> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">我们可以发现,<strong style="font-weight: bold; color: rgb(71, 193, 168);">簇是有向心性的</strong>。也就是说在同一个簇附近的点都会被纳入这个簇的范围内,反过来说就是两个离得远的点属于不同簇的可能性比离得近的大。</p> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">Kmeans++的思路正是基于上面的这两点,我们将目前已经想到的洞见整理一下,就可以得到算法原理了。</p> <h3 data-tool="mdnice编辑器" style="font-weight: bold; font-size: 20px; color: rgb(89,89,89); margin-top: 100px; margin-bottom: 50px;"><span class="prefix" style="display: none;"></span><span class="content">算法原理</span><span class="suffix" style="display: none;"></span></h3> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">首先,其实的簇中心是我们通过在样本当中随机得到的。不过我们并不是一次性随机K个,而是只随机1个。</p> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">接着,我们要从生下的n-1个点当中再随机出一个点来做下一个簇中心。但是我们的随机不是盲目的,我们希望设计一个机制,<strong style="font-weight: bold; color: rgb(71, 193, 168);">使得距离所有簇中心越远的点被选中的概率越大,离得越近被随机到的概率越小</strong>。</p> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">我们重复上述的过程,直到一共选出了K个簇中心为止。</p> <h3 data-tool="mdnice编辑器" style="font-weight: bold; font-size: 20px; color: rgb(89,89,89); margin-top: 100px; margin-bottom: 50px;"><span class="prefix" style="display: none;"></span><span class="content">轮盘法</span><span class="suffix" style="display: none;"></span></h3> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">我们来看一下如何根据权重来确定概率,实现这点的算法有很多,其中比较简单的是<strong style="font-weight: bold; color: rgb(71, 193, 168);">轮盘法</strong>。这个算法应该源于赌博或者是抽奖,原理也非常相似。</p> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">我们或多或少都玩过超市或者是其他场景下的转盘抽奖,在抽奖当中有一个指针一直保持不动。我们转动转盘,当转盘停下的时候,指针所指向的位置就是抽奖的结果。</p> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">我们都知道命中结果的概率和轮盘上对应的面积有关,面积越大抽中的概率也就越大,否则抽中的概率越小。</p> <figure data-tool="mdnice编辑器" style="margin: 0; margin-top: 10px; margin-bottom: 10px;"><img src="https://user-gold-cdn.xitu.io/2020/3/25/1710f1d29da4abd8?w=236&h=214&f=jpeg&s=6972" alt style="display: block; margin: 0 auto; width: auto; max-width: 100%;"></figure> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">我们用公式表示一下,对于每一个点被选中的概率是:</p> <span class="span-block-equation" style="cursor:pointer" data-tool="mdnice编辑器"><figure style="margin: 0; margin-top: 10px; margin-bottom: 10px;"><img class="equation" src="https://juejin.im/equation?tex=P(x_i)=\frac{f(x_i)}{\sum_{j=1}^nf(x_j)} " alt style="display: block; margin: 0 auto; width: auto; max-width: 100%;"></figure></span> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">其中<span class="span-inline-equation" style="cursor:pointer"><span><img style="margin: 0 auto; width: auto; max-width: 100%; display: inline;" class="equation" src="https://juejin.im/equation?tex=f(x_i)" alt></span></span>是每个点到所有类簇的最短距离,<span class="span-inline-equation" style="cursor:pointer"><span><img style="margin: 0 auto; width: auto; max-width: 100%; display: inline;" class="equation" src="https://juejin.im/equation?tex=P(x_i)" alt></span></span>表示点<span class="span-inline-equation" style="cursor:pointer"><span><img style="margin: 0 auto; width: auto; max-width: 100%; display: inline;" class="equation" src="https://juejin.im/equation?tex=x_i" alt></span></span>被选中作为类簇中心的概率。</p> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">轮盘法其实就是一个模拟转盘抽奖的过程,只不过我们用数组模拟了转盘。我们把转盘的扇形拉平,拉成条状,原来的每个扇形就对应了一个区间。扇形的面积就对应了区间的长度,显然长度越长,抽中的概率越大。然后我们来进行抽奖,我们用区间的长度总和乘上一个0-1区间内的数。</p> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">我们找到这个结果落在的区间,就是这次轮盘抽中的结果。这样我们就实现了控制随机每个结果的概率。</p> <figure data-tool="mdnice编辑器" style="margin: 0; margin-top: 10px; margin-bottom: 10px;"><img src="https://user-gold-cdn.xitu.io/2020/3/25/1710f1d28016a871?w=401&h=88&f=png&s=4654" alt style="display: block; margin: 0 auto; width: auto; max-width: 100%;"></figure> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">在上面这张图当中,我们随机出来的值是0.68,然后我们每一次减去区间长度,最后落到的区间,就是我们随机得到的结果。</p> <h3 data-tool="mdnice编辑器" style="font-weight: bold; font-size: 20px; color: rgb(89,89,89); margin-top: 100px; margin-bottom: 50px;"><span class="prefix" style="display: none;"></span><span class="content">总结</span><span class="suffix" style="display: none;"></span></h3> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">明白了轮盘算法之后,整个Kmeans++的思路已经是一览无余了。也就是说我们把抽取类簇中心类比成了轮盘抽奖,我们利用轮盘抽取K个样本来作为初始的类簇中心。从而尽可能地减少迭代次数,逼近最终的结果。</p> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">那么,这样的方法究竟有没有效果呢?</p> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">同样,我们通过实验来证明,首先我们来写出代码。我们需要一个辅助函数用来<strong style="font-weight: bold; color: rgb(71, 193, 168);">计算某个样本和已经选好的簇中心之间的最小距离</strong>,我们要用这个距离来做轮盘算法。</p> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">这个函数很简单,只是计算距离,取最小值而已:</p> <pre class="custom" data-tool="mdnice编辑器" style="margin-top: 10px; margin-bottom: 10px;"><code class="hljs" style="overflow-x: auto; padding: 16px; color: #333; background: #f8f8f8; display: -webkit-box; font-family: Operator Mono, Consolas, Monaco, Menlo, monospace; border-radius: 0px; font-size: 12px; -webkit-overflow-scrolling: touch;"><span class="hljs-function" style="line-height: 26px;"><span class="hljs-keyword" style="color: #333; font-weight: bold; line-height: 26px;">def</span> <span class="hljs-title" style="color: #900; font-weight: bold; line-height: 26px;">get_cloest_dist</span><span class="hljs-params" style="line-height: 26px;">(point, centroids)</span>:</span><br> <span class="hljs-comment" style="color: #998; font-style: italic; line-height: 26px;"># 首先赋值成无穷大,依次递减</span><br> min_dist = math.inf<br> <span class="hljs-keyword" style="color: #333; font-weight: bold; line-height: 26px;">for</span> centroid <span class="hljs-keyword" style="color: #333; font-weight: bold; line-height: 26px;">in</span> centroids:<br> dist = calculateDistance(point, centroid)<br> <span class="hljs-keyword" style="color: #333; font-weight: bold; line-height: 26px;">if</span> dist &lt; min_dist:<br> min_dist = dist<br> <span class="hljs-keyword" style="color: #333; font-weight: bold; line-height: 26px;">return</span> min_dist<br></code></pre> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">接着就是用轮盘法选出K个中心,首先我们先随机选一个,然后再根据距离这个中心的举例用轮盘法选下一个,依次类推,直到选满K个中心为止。</p> <pre class="custom" data-tool="mdnice编辑器" style="margin-top: 10px; margin-bottom: 10px;"><code class="hljs" style="overflow-x: auto; padding: 16px; color: #333; background: #f8f8f8; display: -webkit-box; font-family: Operator Mono, Consolas, Monaco, Menlo, monospace; border-radius: 0px; font-size: 12px; -webkit-overflow-scrolling: touch;"><span class="hljs-keyword" style="color: #333; font-weight: bold; line-height: 26px;">import</span> math<br><span class="hljs-keyword" style="color: #333; font-weight: bold; line-height: 26px;">import</span> random<br><br><span class="hljs-function" style="line-height: 26px;"><span class="hljs-keyword" style="color: #333; font-weight: bold; line-height: 26px;">def</span> <span class="hljs-title" style="color: #900; font-weight: bold; line-height: 26px;">kmeans_plus</span><span class="hljs-params" style="line-height: 26px;">(dataset, k)</span>:</span><br> clusters = []<br> n = dataset.shape[<span class="hljs-number" style="color: #008080; line-height: 26px;">0</span>]<br> <span class="hljs-comment" style="color: #998; font-style: italic; line-height: 26px;"># 首先先选出一个中心点</span><br> rdx = np.random.choice(range(n), <span class="hljs-number" style="color: #008080; line-height: 26px;">1</span>)<br> <span class="hljs-comment" style="color: #998; font-style: italic; line-height: 26px;"># np.squeeze去除多余的括号</span><br> clusters.append(np.squeeze(dataset[rdx]).tolist())<br> d = [<span class="hljs-number" style="color: #008080; line-height: 26px;">0</span> <span class="hljs-keyword" style="color: #333; font-weight: bold; line-height: 26px;">for</span> _ <span class="hljs-keyword" style="color: #333; font-weight: bold; line-height: 26px;">in</span> range(len(dataset))]<br> <span class="hljs-keyword" style="color: #333; font-weight: bold; line-height: 26px;">for</span> _ <span class="hljs-keyword" style="color: #333; font-weight: bold; line-height: 26px;">in</span> range(<span class="hljs-number" style="color: #008080; line-height: 26px;">1</span>, k):<br> tot = <span class="hljs-number" style="color: #008080; line-height: 26px;">0</span><br> <span class="hljs-comment" style="color: #998; font-style: italic; line-height: 26px;"># 计算当前样本到已有簇中心的最小距离</span><br> <span class="hljs-keyword" style="color: #333; font-weight: bold; line-height: 26px;">for</span> i, point <span class="hljs-keyword" style="color: #333; font-weight: bold; line-height: 26px;">in</span> enumerate(dataset):<br> d[i] = get_cloest_dist(point, clusters)<br> tot += d[i]<br> <span class="hljs-comment" style="color: #998; font-style: italic; line-height: 26px;"># random.random()返回一个0-1之间的小数</span><br> <span class="hljs-comment" style="color: #998; font-style: italic; line-height: 26px;"># 总数乘上它就表示我们随机转了轮盘</span><br> tot *= random.random()<br> <span class="hljs-comment" style="color: #998; font-style: italic; line-height: 26px;"># 轮盘法选择下一个簇中心</span><br> <span class="hljs-keyword" style="color: #333; font-weight: bold; line-height: 26px;">for</span> i, di <span class="hljs-keyword" style="color: #333; font-weight: bold; line-height: 26px;">in</span> enumerate(d):<br> tot -= di<br> <span class="hljs-keyword" style="color: #333; font-weight: bold; line-height: 26px;">if</span> tot &gt; <span class="hljs-number" style="color: #008080; line-height: 26px;">0</span>:<br> <span class="hljs-keyword" style="color: #333; font-weight: bold; line-height: 26px;">continue</span><br> clusters.append(np.squeeze(dataset[i]).tolist())<br> <span class="hljs-keyword" style="color: #333; font-weight: bold; line-height: 26px;">break</span><br> <span class="hljs-keyword" style="color: #333; font-weight: bold; line-height: 26px;">return</span> np.mat(clusters)<br></code></pre> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">最后,我们把图画出来看下效果:</p> <figure data-tool="mdnice编辑器" style="margin: 0; margin-top: 10px; margin-bottom: 10px;"><img src="https://user-gold-cdn.xitu.io/2020/3/25/1710f1d2760124b7?w=521&h=254&f=jpeg&s=14822" alt style="display: block; margin: 0 auto; width: auto; max-width: 100%;"></figure> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">上图当中白色的点表示最后收敛的位置,红色的X表示我们用Kmeans++计算得到的起始位置,可以发现距离最终的结果已经非常接近了。显然,我们只需要很少几次迭代就可以达到收敛状态。</p> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">当然<strong style="font-weight: bold; color: rgb(71, 193, 168);">Kmeans++本身也具有随机性</strong>,并不一定每一次随机得到的起始点都能有这么好的效果,但是通过策略,我们可以保证即使出现最坏的情况也不会太坏。</p> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">在实际的场景当中,如果我们真的需要对大规模的数据应用Kmeans算法,我们往往会将多种优化策略结合在一起用,并且多次计算取平均,从而保证在比较短的时间内得到一个足够好的结果。这也是机器学习领域很多算法优化的精髓,即不再追求最优解,而只要一个足够好的解。很多时候,<strong style="font-weight: bold; color: rgb(71, 193, 168);">在结果上一点小小的退让,可以将算法效率提升很多</strong>。</p> <p data-tool="mdnice编辑器" style="font-size: 16px; padding-top: 8px; padding-bottom: 8px; margin: 0; line-height: 26px; color: rgb(89,89,89);">今天关于Kmeans的优化内容就到这些,如果觉得有所收获,请顺手点个<strong style="font-weight: bold; color: rgb(71, 193, 168);">关注或者转发</strong>吧,你们的举手之劳对我来说很重要。</p> </section>

原文出处:https://www.cnblogs.com/techflow/p/12563867.html

展开阅读全文
打赏
0
0 收藏
分享
加载中
更多评论
打赏
0 评论
0 收藏
0
分享
返回顶部
顶部