机器学习中的训练数据不平衡问题

2018/07/07 17:34
阅读数 40

不平衡类的问题是什么?

在一个分类问题中,当你想要预测一个或多个类的所有类中的样本数量极少时,可能会遇到数据中类不平衡的问题。

例子

  • 欺诈预测(欺诈的数量将大大低于真正的交易)

  • 自然灾害预测(不好的事件会比好的情况低很多)

  • 在图像分类中识别恶性肿瘤(在一个训练样本中,肿瘤的图像比没有肿瘤的图像要小得多)

为什么这是个问题?

由于两个主要原因:

  1. 对于实时不平衡的类,我们没有得到优化的结果,因为模型/算法从来没有充分地查看基础类

  2. 它会产生一个验证或测试样本的问题,因为很难在类中进行表示,因为少数类的观察次数极少

解决这个问题的不同方法有哪些?

有三种主要方法建议各有利弊:

  1. 欠采样 - 随机删除具有足够观察值的类,以便两个类的比较比率在我们的数据中显着。虽然这种方法非常简单,但很有可能我们删除的数据可能包含有关重要信息预测类。

  2. 过采样 -对于不平衡类别,随机增加观察值的数量,这些观察值只是现有样本的副本。理想情况下,这给我们足够数量的样本进行播放。过采样可能导致过度拟合训练数据

  3. 合成采样(SMOTE) -该技术要求综合制造与使用最近邻居分类的现有类似的不平衡类别的观测值。问题是当观察次数是极少数类时要做什么。例如,我们可能只有一幅我们想要使用图像分类算法识别的稀有物种的图片

尽管每种方法都有各自的优点,但没有什么特定的启发式方法可以使用。

图像分类中的不平衡类

在本节中,我们将找到一个图像分类问题,其中存在不平衡类问题,然后我们将使用一种简单有效的技术来解决它。

问题 - 我们在kaggle上找到了“驼背鲸识别挑战(Humpback Whale Identification Challenge)”,我们预计这会对解决不平衡类问题提出挑战(理想情况下,被分类的鲸鱼的数量将少于未分类的鲸类,因此,我们将有更少的图像数量)

来自kaggle:“在这场比赛中,你面临挑战,要建立一个算法来识别图像中的鲸鱼种类。您将分析Happy Whale的超过25,000张图像的数据库,这些数据来自研究机构和公共贡献者。通过贡献,您将有助于为全球海洋哺乳动物种群动态开启丰富的理解领域。

让我们开始看数据

由于这是一个多标签图像分类问题,我首先想要检查数据是如何在类中分布的。

机器学习中的训练数据不平衡问题

上面的图表表明,在4251个训练图像中,每个类只有一个图像,还有一些图像具有2-5个图像。现在,这是一个严重的不平衡类问题。我们不能期望DL模型每个类仅使用一个图像进行训练(虽然有些算法可能只是做一些例子,但我们现在忽略了这一点)。这也会产生一个问题,如何在训练和验证样本之间创建一个分界线。您理想情况下希望每个类都在训练样本和验证样本中表示。

我们现在应该做什么?

我们考虑了两个特别的选项:

  • option1在训练样本上进行了严格的数据增量(我们可以这样做,但由于我们只需要针对特定类的数据增量,这可能无法完全解决我们的目的)。因此,我选择了看起来很简单的选项2。

  • option2-类似于我上面提到的过采样选项。我只是使用不同的图像增强技术将不平衡类的图像复制到训练数据中15次。

在我们开始使用选项2之前,可以从训练样本中查看少量图像。

机器学习中的训练数据不平衡问题

这些图像是特定于鲸鱼的fluke,因此,识别将可能是特定于图像的方向。

我也注意到在数据中有很多的图像是特定的B&W或者仅仅是R/B/G通道。

基于这些观察,我决定写下面的代码来做一些图像的小改变,这些图像来自于训练样本ans的不平衡的类:

import os

    from PIL import Image

    from PIL import ImageFilter

    filelist = train['Image'].loc[(train['cnt_freq']<10)].tolist()

    for count in range(0,2):

    for imagefile in filelist:

    os.chdir('/home/paperspace/fastai/courses/dl1/data/humpback/train')

    im=Image.open(imagefile)

    im=im.convert("RGB")

    r,g,b=im.split()

    r=r.convert("RGB")

    g=g.convert("RGB")

    b=b.convert("RGB")

    im_blur=im.filter(ImageFilter.GaussianBlur)

    im_unsharp=im.filter(ImageFilter.UnsharpMask)

    os.chdir('/home/paperspace/fastai/courses/dl1/data/humpback/copy')

    r.save(str(count)+'r_'+imagefile)

    g.save(str(count)+'g_'+imagefile)

    b.save(str(count)+'b_'+imagefile)

    im_blur.save(str(count)+'bl_'+imagefile)

    im_unsharp.save(str(count)+'un_'+imagefile)

以上代码块对不平衡类中的每个图像(频率小于10)都进行如下处理:

  • 将每个图像的增强副本保存为R / B&G

  • 保存每个图像的增强副本,这是blury

  • 保存每张图像不清晰的增强副本

在上面的代码中可以看到,我们使用pillow (一个python图像库)来严格执行此练习

现在我们已经为所有不平衡的类分配了至少10个样本。我们继续进行了培训。

图像增强,我们就这么简单。我们只是想确保我们的模型能够得到关于鲸鱼的fluke的详细的观点。为此,我们将zoom纳入了图像增强。

机器学习中的训练数据不平衡问题

Learning rate finder -我们决定将学习率定为0.01

机器学习中的训练数据不平衡问题

我们使用Resnet50进行了很少的迭代(第一次冻结和解冻)。发现冻结模型对于这个问题陈述也非常有用,因为imagenet中有鲸鱼图像。

机器学习中的训练数据不平衡问题

如何看待测试数据的?

最后,我们在kaggle排行榜上获得了真相。解决方案在本次比赛中提出了34的平均精确度,平均精度为0.41928。

展开阅读全文
打赏
0
0 收藏
分享
加载中
更多评论
打赏
0 评论
0 收藏
0
分享
返回顶部
顶部