半监督学习

原创
07/02 08:57
阅读数 1.7K

半监督学习指的是结合了少量的有标记数据和大量无标记数据来完成训练的过程。

在某些特定领域,大量有标记的数据很少也很难标注。

比方说,我们现在有一个公开数据集,它全部都是有标注的。此时我们可以使用有监督的学习来看一下结果,再使用10%的有标注的数据集结合剩下90%的未标注的数据来使用半监督学习的方法,我们希望半监督学习的方法也能达到有监督学习的水平。

半监督学习的应用

  1. 视频理解,
  2. 自动驾驶
  3. 医疗影像分割
  4. 心脏信号分析

半监督前提假设

  • 连续性假设(Continuity Assumption):

我们用一个分类问题来举例,当我们的Input是比较接近的时候,比如说进行猫狗分类,两张猫的图片是比较接近的时候,此时output(后验概率矩阵)也必须是比较接近的。

比如说,x1和x2比较接近,x1的后验概率是0.9和0.1,明显它是分成第一个类0.9的。x2有两组输出,一个是0.85和0.15,另一个是0.55和0.45。那么我们可以看到这两组的输出虽然都是把类别分到了第一个,但是第二组输出是不满足连续性假设的,因为它和距离比较大,差的比较多。

  • 聚类假设(Cluster Assumption)

聚类假设指的是类类要内聚,类间要分开。就是说同一类的东西要非常相似,比较靠近,接近于一点。不同的类别要尽可能的分开。所以不能有模糊不清的图片,如

  • 流行假设(Manifold Assumption)
  1. 所有数据点都可以被低维流行表达。
  2. 相同流行上的数据点,标签一样。

这里可以理解成降维,很多高维的数据的一些维度是不起作用的,它们的特点集中在一些低维度上。

半监督学习数学定义

上表是一个学术论文上,字符所代表的含义,x代表的是输入;y代表的是输出,它要么是个分类输出,要么是个回归输出;代表有标签的数据集;代表无标签的数据集;X就是整个的数据集,包含有标签的和无标签的;L指损失函数;G是生成器,半监督学习可以用到生成式模型;D是判别器;C是分类器;H是熵,一般指交叉熵;E是期望;R是正则项,半监督学习中一般指一致性正则,当然半监督学习也可以使用传统的L1和L2正则;是指标签。

半监督学习最核心的其实就是它的损失函数,它一般包含三个部分,第一部分就是有监督的loss(supervised loss),第二部分就是无监督的loss(unsupervised loss)以及第三部分正则项(regularization)。因为半监督学习有少量的有标签的数据,那么第一部分就是这些有标签数据的loss;当然还有大量的未标注的数据,第二部分就是这些未标注数据的loss;第三部分可以用L1、L2正则,也可以是一致性正则。

第一部分的loss跟之前是一样的,一般是交叉熵损失函数,最主要的就是设计后面两部分的损失函数。

半监督学习实施方法

半监督学习模型可以分为五大方法,第一个是生成式模型,第二个是一致性损失正则,第三个是图神经网络,第四个是伪标签的方法,第五个是混合方法。现在用的最多的是混合方法,它可以结合前面四种方法的优点。

  • Generative Based:基于生成式网络

1、重用判别器(Re-using Discriminator)

在我们使用GAN的时候,我们知道,鉴别器充当的是二分类器的功能,对输入的真实的图片或者生成的图片来判定是真是假。重用鉴别器在半监督学习中是一个K分类的分类器,它不仅仅是对有标签的数据(x,y)进行分类,还有生成的数据(G(z))和未标注的数据x进行分类。通过这三块的损失来构建我们的K类别的分类器。这样就达到了我们的目的,联合了未标注的数据和有标签的数据。

2、用于正则化分类器的生成样本(Generated samples to regularize a classifier)

这里的鉴别器D依然是一个二分类器,生成器G生成数据的时候的输入包含了未标注数据x,还包含了某一分布的随机初始矩阵z,来共同生成,再由生成,生成的公式如下

这里的m是一个二值化的掩膜,即一个和x一样大的矩阵,它的值只有0和1。0乘以x中的像素点直接置为0,而1会保留x中的像素点的值。最后联合x和一同送入鉴别器D中来判别它们是否是一致的。我们希望我们的判别结果是一致的,这就意味着能驱动判别器D来识别到图片的某一块的特征。一旦该模型训练完备之后,就可以单独将鉴别器提取出来用在别的分类器中去。也可以用于构建别的loss设计的一部分,相当于一个表征或特征抽取器

3、推理模型(inference model)

这是一个统称,不是指具体某一个模型的名字,有很多。

它跟第一种重用判别器很像,多了一个C(类别)。前面的步骤是相同的,只是在最后在判别器D这里多了一个类别,不是K个类别而是K+1个类别。多出来的这个类别就是生成器生成的G(z)的类别,它需要跟真实的K个类别的某一个类别要接近,这就是它的目的。

4、生成数据(Generate Data)

生成网络可以用在数据增强,生成更多的数据来。因为我们未标记的数据有很多,那我们干脆直接训练一个生成器,让它造更多数据出来。

  • Consistency Regularization:一致性正则

这种方式是半监督学习的核心

设计思路:

这里θ是指模型参数,也就是模型。x是未标注的数据,指的是标签。

上图中,未标注数据x经过两种不同的随机数据增强Aug1和Aug2,也就是随机的翻转,旋转,平移,光照等等。然后送入模型中,让模型进行识别,会得到一个后验概率或者特征,我们希望输出的两个值是接近的。因为我们的输入是接近的,虽然x经过两种不同的扰动,但输出应该要接近。用公式表示为

这里的ζ指的是随机数据增强。ζ1和ζ2是两种不同的随机数据增强。

每个训练的epoch,会被前向推理两次,这两次虽然输入经过不同的随机增广,但是输出应该具有一致性。

其实这种扰动不单单是可以用随机数据增强,还可以使用很多的方法。

上表中是半监督学习经常刷榜的模型,它们的核心都在一致性正则上。比如说第三个,对于两种扰动,第二个扰动加了EMA(指数移动平均);第四个是在第二种扰动中对模型参数加了EMA;最后一个对于同样的模型,不增广,而是直接在模型上加了扰动。

  • Pseudo-label:伪标签

半监督学习的大量数据是没有标签的,那么我们使用模型来预测一个标签,然后再送进模型训练。

伪标签的损失函数如下

其中第一部分是有标记数据的损失,是真实的标签,是有标记数据的前向推理值。第二部分是未标注数据的损失,是伪标签,也就是预测出来的标签,是未标记数据的前向推理值。伪标签看似是一个简单的思路,但其实涉及到的方法也很多,它可能跟一致性正则一样,在结构上做设计,或者在训练的流程上做设计以及伪标签预测的方法上做设计。

伪标签有一个弊端

  1. 伪标签选择不太容易,在模型训练初期,可能是一个不太好的模型,预测出来的标签极有可能是不正确的。如果此时再将预测出来的标签送进模型训练可能会引起进一步的崩溃。
  2. 在伪标签损失函数中第二部分有一个,它的意思代表伪标签损失值占整个损失函数多大的比重。而这个的权重值也是很难确定的。如果太小,则未标注数据就失去了作用;太大,如果预测出来的伪标签是不正确的,会导致损失结果难以收敛。

 MixMatch半监督学习

MixMatch结合了之前说的几种方法,用了单个loss,将这几种方式进行合并,如一致性正则,最小化熵,传统正则。它有一个很重要的方法叫MixUp。

它取的有标记数据和无标记数据的BatchSize是一样大的,不过无标记数据会经过K个增强。首先会对有标记数据进行增强,再对无标记数据进行K次增强,再将增强后的无标记数据送入模型,每一种增强的无标记数据会预测一个结果(softmax的结果),将结果取均值

然后再锐化。

上式中p为无标记数据输出结果的后验概率,T为温度项,当T接近于0的时候,分类结果会出现one-hot的情况,当T趋向于+∞的时候,分类结果会出现无差别,都一样的情况。经过这么处理之后会使得后验概率的直方图会很尖锐(其中一个分类项会特别突出)

之后会得到有标记的数据,和无标记数据以及猜测出的标记,然后将这两种数据给拼接(concat)起来,组合成一个大的数据,再随机打乱拼接后的数据与有标记数据和带猜测标签的无标记数据进行混合MixUp。

它先从Beta抽样中,抽样一个值出来,然后从该值与1-该值中取最大值。x1是有标记数据和无标记数据的拼接(注意是未随机打乱的),x2是这两种数据拼接后再随机打乱后的数据。p1是有标记数据的标签和无标记数据的猜测标签的拼接(未随机打乱),p2是这两种数据拼接后再随机打乱的标签(可能是真实标签也可能是猜测标签)。这样就得到了新的数据和标签。

最后再送入模型,求损失函数。

它的整体流程的伪代码如下

代码实现

超参数设置

import torch

# ################################################################
#                             HyperParameters
# ################################################################
# semi-supervised learning:
#     1. model structure
#     2. hype setting are important!
class Hyperparameters:
    # ################################################################
    #                             Data
    # ################################################################
    device = 'cuda' if torch.cuda.is_available() else 'cpu'  # cuda for training, cpu/cuda for inference
    classes_num = 10  # 分类数
    n_labeled = 250  # 已标记数据总数
    seed = 1234

    # ################################################################
    #                             Model
    # ################################################################
    T = 0.5  # 锐化温度项(sharpen temperature)
    K = 2  # 数据增强次数
    alpha = 0.75  # 伪标签损失权值
    lambda_u = 75.  # 一致性损失权值
    # ################################################################
    #                             Exp
    # ################################################################
    batch_size = 8
    init_lr = 0.002
    epochs = 1000
    verbose_step = 300
    save_step = 300

HP = Hyperparameters()

数据集,这里使用的是cifar10数据集

import numpy as np
from torchvision import transforms
import torchvision
import torch


class TransformTwice:
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, inp):
        out1 = self.transform(inp)
        out2 = self.transform(inp)
        return out1, out2


def get_cifar10(root, n_labeled,
                 transform_train=None, transform_val=None,
                 download=True):
    # 获取cifar10数据集
    base_dataset = torchvision.datasets.CIFAR10(root, train=True, download=download)
    # 将该数据集拆分成有标记的训练数据集,无标记的训练数据集和验证集
    train_labeled_idxs, train_unlabeled_idxs, val_idxs = train_val_split(base_dataset.targets, int(n_labeled / 10))
    # datasset->dataload
    train_labeled_dataset = CIFAR10_labeled(root, train_labeled_idxs, train=True, transform=transform_train)
    train_unlabeled_dataset = CIFAR10_unlabeled(root, train_unlabeled_idxs, train=True, transform=TransformTwice(transform_train))
    val_dataset = CIFAR10_labeled(root, val_idxs, train=True, transform=transform_val, download=True)
    test_dataset = CIFAR10_labeled(root, train=False, transform=transform_val, download=True)

    print(f"#Labeled: {len(train_labeled_idxs)} #Unlabeled: {len(train_unlabeled_idxs)} #Val: {len(val_idxs)}")
    return train_labeled_dataset, train_unlabeled_dataset, val_dataset, test_dataset
    

def train_val_split(labels, n_labeled_per_class):
    labels = np.array(labels)
    train_labeled_idxs = []
    train_unlabeled_idxs = []
    val_idxs = []

    for i in range(10):
        idxs = np.where(labels == i)[0]
        np.random.shuffle(idxs)
        train_labeled_idxs.extend(idxs[:n_labeled_per_class])
        train_unlabeled_idxs.extend(idxs[n_labeled_per_class:-500])
        val_idxs.extend(idxs[-500:])
    np.random.shuffle(train_labeled_idxs)
    np.random.shuffle(train_unlabeled_idxs)
    np.random.shuffle(val_idxs)

    return train_labeled_idxs, train_unlabeled_idxs, val_idxs


cifar10_mean = (0.4914, 0.4822, 0.4465)  # equals np.mean(train_set.train_data, axis=(0,1,2))/255
cifar10_std = (0.2471, 0.2435, 0.2616)  # equals np.std(train_set.train_data, axis=(0,1,2))/255


def normalise(x, mean=cifar10_mean, std=cifar10_std):
    x, mean, std = [np.array(a, np.float32) for a in (x, mean, std)]
    x -= mean*255
    x *= 1.0/(255*std)
    return x


def transpose(x, source='NHWC', target='NCHW'):
    return x.transpose([source.index(d) for d in target])


def pad(x, border=4):
    return np.pad(x, [(0, 0), (border, border), (border, border)], mode='reflect')


class RandomPadandCrop(object):
    """Crop randomly the image.

    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, x):
        x = pad(x, 4)

        h, w = x.shape[1:]
        new_h, new_w = self.output_size

        top = np.random.randint(0, h - new_h)
        left = np.random.randint(0, w - new_w)

        x = x[:, top: top + new_h, left: left + new_w]

        return x


class RandomFlip(object):
    """Flip randomly the image.
    """
    def __call__(self, x):
        if np.random.rand() < 0.5:
            x = x[:, :, ::-1]

        return x.copy()


class GaussianNoise(object):
    """Add gaussian noise to the image.
    """
    def __call__(self, x):
        c, h, w = x.shape
        x += np.random.randn(c, h, w) * 0.15
        return x


class ToTensor(object):
    """Transform the image to tensor.
    """
    def __call__(self, x):
        x = torch.from_numpy(x)
        return x


class CIFAR10_labeled(torchvision.datasets.CIFAR10):

    def __init__(self, root, indexs=None, train=True,
                 transform=None, target_transform=None,
                 download=False):
        super(CIFAR10_labeled, self).__init__(root, train=train,
                 transform=transform, target_transform=target_transform,
                 download=download)
        if indexs is not None:
            self.data = self.data[indexs]
            self.targets = np.array(self.targets)[indexs]
        self.data = transpose(normalise(self.data))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], self.targets[index]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target
    

class CIFAR10_unlabeled(CIFAR10_labeled):

    def __init__(self, root, indexs, train=True,
                 transform=None, target_transform=None,
                 download=False):
        super(CIFAR10_unlabeled, self).__init__(root, indexs, train=train,
                 transform=transform, target_transform=target_transform,
                 download=download)
        self.targets = np.array([-1 for i in range(len(self.targets))])


transform_train = transforms.Compose([
        RandomPadandCrop(32),
        RandomFlip(),
        ToTensor(),
    ])

transform_val = transforms.Compose([
    ToTensor(),
])

现在增加一些工具方法,第一个是EMA(Exponential Moving Average),指数移动平均,是时间序列分析中常用到的一种类型平均值。简单来说,EMA就是一个加权平均值。它的特别之处在于:

  1. 随着时间流逝,旧的观察值的权重将会呈现指数衰减(Exponential Decay)
  2.  [公式] 代表距离当前时刻 [公式] 之前的观察值的权重,那么 [公式]
  3. 其中 [公式] 掌控着指数衰减的程度, [公式] 越大,权重随时间衰减得越快。

那么顾名思义,EMA说到底就是一个加权平均值,可以根据加权平均值的定义来写出来

[公式]

EMA的迭代定义

  1. 等间距的时间序列 [公式] (等间距:即两个相邻样本之间的时间间隔是不变的)
  2. 衰变权重是 [公式]  [公式] 越大,之前的观察样本获得的权重越小、衰变越多)

那么,EMA(由 [公式] 表示在时间 [公式]的EMA)的迭代公式的定义为:

[公式]

其中 [公式] 是EMA的初始值。

 [公式] 的迭代公式出发,持续迭代进去 [公式] ,我们得到

[公式]

假设我们有无穷多个观察值,那么可以把上式进一步展开为

[公式]

其中我们用到了泰勒序列 [公式] 。令 [公式] ,这个等式就与上文中的加权平均定义一模一样了。也就观察值无穷多的时候,二者收敛到同等的值。

import torch
import numpy as np


class WeightEMA:
    # 指数移动平均,使训练流程更为稳固

    def __init__(self, model, ema_model, alpha=0.999):
        # 模型
        self.model = model
        # 影子模型,用于计算指数移动平均参数的模型
        self.ema_model = ema_model
        # 衰变权重
        self.alpha = alpha
        # 模型参数
        self.params = list(model.state_dict().values())
        # 影子模型参数
        self.ema_params = list(ema_model.state_dict().values())
        # 模型权重值
        self.weight_decacy = 0.0004
        # 同步模型参数
        for param, ema_param in zip(self.params, self.ema_params):
            param.data.copy_(ema_param)

    def step(self):
        # 每一步迭代
        for param, ema_param in zip(self.params, self.ema_params):
            # 参数类型必须为浮点型
            if ema_param.dtype == torch.float32:  # model weights only!
                # 获取当前影子模型的参数(带权重)
                ema_param.mul_(self.alpha)
                # 当前影子模型的参数加上上一步的模型参数(带权重)
                # EMA过程结束
                ema_param.add_(param * (1 - self.alpha))
                # 对整个模型参数进行缩放
                param.mul_((1 - self.weight_decacy))

第二个是求总体损失函数的未标记数据损失函数的权重,这个权重我们之前说很难确定

这里使用的是稳步增长的方式来获取这个权重

def lambda_rampup(step, MAX_STEP=1e6, max_v=75):
    """
    求超参数损失函数未标记数据损失函数权重
    :param step: 训练步数
    :param MAX_STEP: 最大训练步数
    :param max_v: 权重最大值
    :return: 当前权重值
    """
    return np.clip(a=max_v * (step / MAX_STEP), a_min=0., a_max=max_v)

第三个是标签猜测

def label_guessing(out_u, out_u2):
    """
    经过两次数据增强K=2(default)进行标签猜测
    :param out_u: [N, 10], 第一次数据增强的结果
    :param out_u2: [N, 10],第二次数据增强的结果
    :return: average label guessing, [N, 10]
    [[0.22, 0.32......], => sum = 1.
    [0.01, 0.3, 0.03...], => sum = 1.
    ....]
    """
    # 对两次数据增强的结果的后验概率取均值,即除以2
    q = (torch.softmax(out_u, dim=-1) + torch.softmax(out_u2, dim=-1)) / 2.
    # algorithm 1, line 7
    return q

第四个是锐化

def sharpen(p, T):
    """
    锐化,使输出更加尖锐,整体计算方式为对标签猜测后的后验概率求1/T次方再求和
    然后再使用标签猜测后的后验概率求1/T次方除以求和的值
    :param p: 后验概率
    [[0.22, 0.32......], => sum = 1.
    [0.01, 0.3, 0.03...], => sum = 1.
    ....]
    :param T: 温度项
    :return: sharpened result
    """
    p_power = torch.pow(p, 1./T)
    return p_power / torch.sum(p_power, dim=-1, keepdim=True)  # [N , 10]

第五个是MixUp

def mixup(x, u, u2, trg_x, out_u, out_u2, alpha=0.75):
    """
    mixup: 将打乱后的有标签数据和无标签数据的拼接与有标签数据和无标签数据进行混合
    :param x: 有标记数据, [N, 3, H, W]
    :param u: 第一种数据增强的无标记数据, [N, 3, H, W]
    :param u2: 第二种数据增强的无标记数据, [N, 3, H, W]
    :param trg_x: 有标记数据标签,[N, ]=[0, 7, 8...]
    :param out_u: 第一种数据增强的猜测标签
    :param out_u2: 第二种数据增强的猜测标签
    :param alpha: Beta抽样的超参
    :return: mixuped result: x: [3*N, 3, H, W], y: [3*N, 10]
    """
    batch_size = x.size(0)  # batch size = HP.batch_size
    n_classes = out_u.size(1)  # classes number: 10
    device = x.device
    # [0.1,0.3.0.01.....] dim=10
    # [0., 0.,0., 0.,0., 0.,0., 0.,1., 0.,] dim=10
    # 将有标记数据的标签转成one-hot形式
    trg_x_onehot = torch.zeros(size=(batch_size, n_classes)).float().to(device)
    # [0, 0., 0., 0., 0., 0, 0., 0., 0., 0.,]
    # trg_x [7]
    # [0, 0., 0., 0., 0., 0, 0., 1., 0., 0.,]
    trg_x_onehot.scatter_(1, trg_x.view(-1, 1), 1.)

    # 拼接有标记数据和无标记数据
    x_cat = torch.cat([x, u, u2], dim=0)
    trg_cat = torch.cat([trg_x_onehot, out_u, out_u2], dim=0)
    # 获取这两种数据的总量
    n_item = x_cat.size(0)  # N*3
    # beta抽样
    lam = np.random.beta(alpha, alpha)  # eq. (8)
    # 获取抽样的值与1-该值的最大值
    lam_prime = max(lam, 1 - lam)         # eq. (9)
    # 将拼接后的数据随机打乱
    rand_idx = torch.randperm(n_item)   # a rand index sequence: [0,2, 1], [1, 0, 2]
    x_cat_shuffled = x_cat[rand_idx]    # x2
    trg_cat_shuffled = trg_cat[rand_idx]
    # 获取输出数据
    x_cat_mixup = lam_prime * x_cat + (1 - lam_prime) * x_cat_shuffled    # eq. (9)
    # 获取输出标签
    trg_cat_mixup = lam_prime * trg_cat + (1 - lam_prime) * trg_cat_shuffled  # eq. (10)

    return x_cat_mixup, trg_cat_mixup

第六个是模型精度

def accuracy(output, target, topk=(1, )):
    """
    topk 准确率
    :param output: 模型输出[N, 10]
    :param target: 模型标签[N, ]
    :param topk: top1,top3, top5
    :return: acc list
    """
    maxk = max(topk)  # max k, topk=(1, 3, 5)
    batch_size = target.size(0)
    # 获取topk的索引
    _, pred = output.topk(maxk, 1, True, True)
    # 转置
    pred = pred.t()  # [maxk, N]
    # 获取比较矩阵
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    # 获取每一种top的准确率
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100. / batch_size))
    return res  # [50, 85, 99]

接下来是半监督学习的模型,模型在这里其实并不重要,我们可以任选一种深度学习模型进行特征提取即可,这里选择的是WideResNet(宽残差网络),并且是Pytorch自带的迁移学习模型,这样我们就不用自己去搭建模型了。

import torch
import torchvision
from torch import nn
from config import HP


class WideResnet50_2(nn.Module):
    def __init__(self):
        super(WideResnet50_2, self).__init__()
        resnet = torchvision.models.wide_resnet50_2(pretrained=False)
        last_fc_dim = resnet.fc.in_features  # defaut imagenet, 1000
        fc = nn.Linear(in_features=last_fc_dim, out_features=HP.classes_num)
        resnet.fc = fc
        self.wideresnet4cifar10 = resnet

    def forward(self, input_x):
        return self.wideresnet4cifar10(input_x)

然后是损失函数

import torch
import torch.nn.functional as F
from torch import nn


class MixUpLoss(nn.Module):
    def __init__(self):
        super(MixUpLoss, self).__init__()

    def forward(self, output_x, trg_x, output_u, trg_u):
        """
        loss function: 损失函数
        :param output_x: 混合后的已标记数据模型输出: [N, 10]
        :param trg_x: trg_x: 混合后的已标记数据标签: [N, 10]
        :param output_u: 混合后的经过2次数据增强的未标记数据模型输出: [2*N, 10]
        :param trg_u:  混合后的经过2次数据增强的未标记数据猜测标签: [2*N, 10]
        :return:Lx, Lu
        """
        # 对混合后的已标记数据模型输出求交叉熵损失函数
        Lx = -torch.mean(torch.sum(F.log_softmax(output_x, dim=-1) * trg_x, dim=-1))  # cross-entropy, supervised loss
        # 对混合后的经过2次数据增强的未标记数据模型输出求均方误差损失函数
        Lu = F.mse_loss(output_u, trg_u)  # consistency reg
        return Lx, Lu

因为总的损失函数还有一个还没确定,所以总的损失函数会在训练部分添加。

模型训练,首先我们会做一些准备性的工作,包括对数据集dataloader的读取以及对模型的保存和评价

import os
import random
from argparse import ArgumentParser

import torch.cuda
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
from model import WideResnet50_2
import dataset.cifar10 as dataset
from utils import *
from tensorboardX import SummaryWriter
from config import HP
from loss import MixUpLoss

# 随机种子
torch.manual_seed(HP.seed)
torch.cuda.manual_seed(HP.seed)
random.seed(HP.seed)
np.random.seed(HP.seed)

# 用于训练的数据增强和转化
transform_train = transforms.Compose([
    dataset.RandomPadandCrop(32),
    dataset.RandomFlip(),
    dataset.ToTensor(),
])

# 用于推理,验证,测试的数据转化
transform_val = transforms.Compose([
    dataset.ToTensor(),
])

# labeled dataloader / 2 unlabeled dataloaders / validation dataloader
train_labeled_set, train_unlabeled_set, val_set, test_set = dataset.get_cifar10('./data',
                                                                                n_labeled=HP.n_labeled,
                                                                                transform_train=transform_train,
                                                                                transform_val=transform_val)
labeled_trainloader = data.DataLoader(train_labeled_set, batch_size=HP.batch_size, shuffle=True, drop_last=True)
unlabeled_trainloader = data.DataLoader(train_unlabeled_set, batch_size=HP.batch_size, shuffle=True, drop_last=True)
val_loader = data.DataLoader(val_set, batch_size=HP.batch_size, shuffle=False, drop_last=False)

logger = SummaryWriter('./log')


# 影子模型
def new_ema_model():
    model = WideResnet50_2()
    model = model.to(HP.device)
    for param in model.parameters():
        param.detach_()  # 避免梯度追踪
    return model


# 保存模型参数
def save_checkpoint(model_, ema_model_, epoch_, optm, checkpoint_path):
    save_dict = {
        'epoch': epoch_,
        'model_state_dict': model_.state_dict(),
        'ema_model_state_dict': ema_model_.state_dict(),
        'optimizer_state_dict': optm.state_dict(),
    }
    torch.save(save_dict, checkpoint_path)


# 评价
def evaluate(model_, val_loader_, crit):
    model_.eval()
    sum_loss = 0.
    acc1, acc5 = 0., 0.
    with torch.no_grad():
        for batch in val_loader_:
            # load eval data
            inputs_x, trg_x = batch
            inputs_x, trg_x = inputs_x.to(HP.device), trg_x.long().to(HP.device)
            out_x = model_(inputs_x)  # model inference
            top1, top5 = accuracy(out_x, trg_x, topk=(1, 5))
            acc1 += top1
            acc5 += top5
            sum_loss += crit(out_x, trg_x)  #计算loss
    loss = sum_loss / len(val_loader_)
    acc1 = acc1 / len(val_loader_)
    acc5 = acc5 / len(val_loader_)
    model_.train()
    return acc1, acc5, loss

然后是最重要的训练

def train():
    parser = ArgumentParser(description='Model Training')
    parser.add_argument(
        '--c',
        default=None,
        type=str,
        help='train from scratch or resume from checkpoint'
    )
    args = parser.parse_args()

    # 新建模型和影子模型
    model = WideResnet50_2()
    model = model.to(HP.device)
    ema_model = new_ema_model()
    # 将模型和影子模型送入指数移动平均
    model_ema_opt = WeightEMA(model, ema_model)

    # loss
    criterion_val = nn.CrossEntropyLoss()  # for eval
    criterion_train = MixUpLoss()   # for training

    opt = optim.Adam(model.parameters(), lr=HP.init_lr, weight_decay=0.001)  # optimizer with L2 regular

    start_epoch, step = 0, 0
    if args.c:  # 读取保存的模型
        checkpoint = torch.load(args.c)
        model.load_state_dict(checkpoint['model_state_dict'])
        ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
        opt.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch']
        print('Resume From %s.' % args.c)
    else:
        print('Training from scratch!')

    model.train()
    eval_loss = 0.
    n_unlabeled = len(unlabeled_trainloader)  # 获取未标记的数据量

    # train loop
    for epoch in range(start_epoch, HP.epochs):
        print('Start epoch: %d, Step: %d' % (epoch, n_unlabeled))
        for i in range(n_unlabeled):  # one unlabeled data turn as an epoch
            # inputs_x: [N, 3, H, W], trg_x: [N,]
            inputs_x, trg_x = next(iter(labeled_trainloader))  # 获取一个batch的已标记数据和标签
            # inputs_u / inputs_u2 -> [N, 3, H, W]
            (inputs_u, inputs_u2), _ = next(iter(unlabeled_trainloader))  # 获取一个batch的未标记数据(两份)
            inputs_x, trg_x, inputs_u, inputs_u2 = inputs_x.to(HP.device), trg_x.long().to(HP.device), inputs_u.to(HP.device), inputs_u2.to(HP.device)

            # 对未标记数据进行标签猜测并锐化
            with torch.no_grad():
                out_u = model(inputs_u)  # Aug K=1, inference [N, 10]
                out_u2 = model(inputs_u2)  # Aug K=2, inference [N, 10]
                q = label_guessing(out_u, out_u2)  # average post distribution [N, 10]
                q = sharpen(q, T=HP.T)  # [N, 10],

            # 对已标记数据和标签以及未标记数据和猜测标签进行混合得到混合后的数据和标签
            # mixuped_x: [3*N, 3, H, W], mixuped_out: [3*N, 10]
            mixuped_x, mixuped_out = mixup(x=inputs_x, u=inputs_u, u2=inputs_u2, trg_x=trg_x, out_u=q, out_u2=q)

            # 对混合后的数据进行前向推理
            mixuped_logits = model(mixuped_x)  # [3*N, 10]
            logits_x = mixuped_logits[:HP.batch_size]  # [N, 10]
            logits_u = mixuped_logits[HP.batch_size:]  # [2*N, 10]

            # 获取已标记数据和未标记数据的损失函数
            loss_x, loss_u = criterion_train(logits_x, mixuped_out[:HP.batch_size], logits_u, mixuped_out[HP.batch_size:])
            # 合并成总的损失函数
            loss = loss_x + lambda_rampup(step, max_v=HP.lambda_u) * loss_u  # eq. (5)
            # 对损失函数进行反向传播
            logger.add_scalar('Loss/Train', loss, step)
            opt.zero_grad()
            loss.backward()
            opt.step()
            # 进行指数移动平均的模型参数计算
            model_ema_opt.step()

            if not step % HP.verbose_step:  # 模型评价
                acc1, acc5, eval_loss = evaluate(model, val_loader, criterion_val)
                logger.add_scalar('Loss/Dev', eval_loss, step)
                logger.add_scalar('Acc1', acc1, step)
                logger.add_scalar('Acc5', acc5, step)

            if not step % HP.save_step:  # 保存模型参数
                model_path = 'model_%d_%d.pth' % (epoch, step)
                save_checkpoint(model, ema_model, epoch, opt, os.path.join('./model_save', model_path))

            print('Epcoh: [%d/%d], step: %d, Train Loss: %.5f, Dev Loss: %.5f, Acc1: %.3f, Acc5: %.3f'%
                  (epoch, HP.epochs, step, loss.item(), eval_loss, acc1, acc5))
            step += 1
            logger.flush()
    logger.close()
展开阅读全文
加载中

作者的其它热门文章

打赏
1
1 收藏
分享
打赏
0 评论
1 收藏
1
分享
返回顶部
顶部