使用注意力机制来做医学图像分割的解释和Pytorch实现

原创
2020/08/14 09:08
阅读数 1K

点击上方“AI公园”,关注公众号,选择加“星标“或“置顶”


作者:Léo Fillioux

编译:ronghuaiyang

导读

对两篇近期的使用注意力机制进行分割的文章进行了分析,并给出了简单的Pytorch实现。

从自然语言处理开始,到最近的计算机视觉任务,注意力机制一直是深度学习研究中最热门的领域之一。在这篇文章中,我们将集中讨论注意力是如何影响医学图像分割的最新架构的。为此,我们将描述最近两篇论文中介绍的架构,并尝试给出一些关于这两篇文章中提到的方法的直觉,希望它能给你一些想法,让你能够将注意力机制应用到自己的问题上。我们还将看到简单的PyTorch实现。

医学图像分割与自然图像的区别主要有两点:

  • 大多数医学图像都非常相似,因为它们是在标准化设置中拍摄的,这意味着在图像的方向、位置、像素范围等方面几乎没有变化。
  • 通常在正样本像素(或体素)和负样本像素之间存在很大的不平衡,例如在尝试分割肿瘤时。

注意:当然,代码和解释都是对论文中描述的复杂架构的简化,其目的主要是给出一个关于做了什么的直觉和一个好的想法,而不是解释每一个细节。

1. Attention UNet

UNet是用于分割的主要架构,目前在分割方面的大多数进展都使用这种架构作为骨干。在本文中,作者提出了一种将注意力机制应用于标准UNet的方法。

1.1. 提出了什么方法

该结构使用标准UNet作为骨干,并且不改变收缩路径。改变的是扩展路径,更准确地说,注意力机制被整合到跳转连接中。

attention UNet的框图,扩展路径block用红色框出

为了解释展开路径的block是如何工作的,让我们把来自前一个block的输入称为g,以及来自扩展路径的skip链接称为x。下面的式子描述了这个模块是如何工作的。

upsample块非常简单,而ConvBlock只是由两个(convolution + batch norm + ReLU)块组成的序列。唯一需要解释的是注意力。

注意力block的框图。这里的维度假设输入图像维度为3。
  • xg都被送入到1x1卷积中,将它们变为相同数量的通道数,而不改变大小
  • 在上采样操作后(有相同的大小),他们被累加并通过ReLU
  • 通过另一个1x1的卷积和一个sigmoid,得到一个0到1的重要性分数,分配给特征图的每个部分
  • 然后用这个注意力图乘以skip输入,产生这个注意力块的最终输出

1.2. 为什么这样是有效的

在UNet中,可将收缩路径视为编码器,而将扩展路径视为解码器。UNet的有趣之处在于,跳跃连接允许在解码器期间直接使用由编码器提取的特征。这样,在“重建”图像的掩模时,网络就学会了使用这些特征,因为收缩路径的特征与扩展路径的特征是连接在一起的。

在此连接之前应用一个注意力块,可以让网络对跳转连接相关的特征施加更多的权重。它允许直接连接专注于输入的特定部分,而不是输入每个特征。

将注意力分布乘上跳转连接特征图,只保留重要的部分。这种注意力分布是从所谓的query(输入)和value(跳跃连接)中提取出来的。注意力操作允许有选择地选择包含在值中的信息。此选择基于query。

总结:输入和跳跃连接用于决定要关注跳跃连接的哪些部分。然后,我们使用skip连接的这个子集,以及标准展开路径中的输入。

1.3. 简短的实现

下面的代码定义了注意力块(简化版)和用于UNet扩展路径的“up-block”。“down-block”与原UNet一样。

class AttentionBlock(nn.Module):
    def __init__(self, in_channels_x, in_channels_g, int_channels):
        super(AttentionBlock, self).__init__()
        self.Wx = nn.Sequential(nn.Conv2d(in_channels_x, int_channels, kernel_size = 1),
                                nn.BatchNorm2d(int_channels))
        self.Wg = nn.Sequential(nn.Conv2d(in_channels_g, int_channels, kernel_size = 1),
                                nn.BatchNorm2d(int_channels))
        self.psi = nn.Sequential(nn.Conv2d(int_channels, 1, kernel_size = 1),
                                 nn.BatchNorm2d(1),
                                 nn.Sigmoid())
    
    def forward(self, x, g):
        # apply the Wx to the skip connection
        x1 = self.Wx(x)
        # after applying Wg to the input, upsample to the size of the skip connection
        g1 = nn.functional.interpolate(self.Wg(g), x1.shape[2:], mode = 'bilinear', align_corners = False)
        out = self.psi(nn.ReLU()(x1 + g1))
        out = nn.Sigmoid()(out)
        return out*x

class AttentionUpBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(AttentionUpBlock, self).__init__()
        self.upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size = 2, stride = 2)
        self.attention = AttentionBlock(out_channels, in_channels, int(out_channels / 2))
        self.conv_bn1 = ConvBatchNorm(in_channels+out_channels, out_channels)
        self.conv_bn2 = ConvBatchNorm(out_channels, out_channels)
    
    def forward(self, x, x_skip):
        # note : x_skip is the skip connection and x is the input from the previous block
        # apply the attention block to the skip connection, using x as context
        x_attention = self.attention(x_skip, x)
        # upsample x to have th same size as the attention map
        x = nn.functional.interpolate(x, x_skip.shape[2:], mode = 'bilinear', align_corners = False)
        # stack their channels to feed to both convolution blocks
        x = torch.cat((x_attention, x), dim = 1)
        x = self.conv_bn1(x)
        return self.conv_bn2(x)
在使用注意力时,注意力块和UNet扩展路径块的简单的实现。

注意:ConvBatchNorm是一个由Conv2d、BatchNorm2d和ReLU激活函数组成的sequence。

2. Multi-scale guided attention

我们将要讨论的第二个架构比第一个架构更有独创性。它不依赖于UNet架构,而是依赖于特征提取,然后跟一个引导注意力块。

所提出的方法的Block图

第一部分是从图像中提取特征。为此,我们将输入图像输入到一个预先训练好的ResNet中,提取4个不同层次的特征图。这很有趣,因为低层次的特征往往出现在网络的开始阶段,而高层次的特性往往出现在网络的结束阶段,所以我们将能够访问到多种尺度的特征。使用bilinear插值将所有的特征图上采样到最大的一个。这给了我们4个相同大小的特征图,它们被连接并送入一个卷积块。这个convolutional block (multi-scale feature map)的输出与4个feature map的每一个都连接在一起,这给出了我们的attention blocks的输入,这个输入比之前的要复杂一些。

2.1. 提出了什么

引导注意力块依赖于位置和通道注意力模块,我们从总体描述开始。

位置和通道注意力模块的框图

我们将尝试理解这些模块中发生了什么,但是我们不会详细介绍这两个模块中的每个操作(可以通过下面的代码部分理解)。

这两个块实际上非常相似,它们之间的唯一区别在于从通道还是位置提取信息。在flatten之前进行卷积会使位置更加重要,因为在卷积过程中通道的数量会减少。在通道注意力模块中,在reshape的过程中,原有通道数量被保留,这样更多的权重给到了通道上。

在每个block中,需要注意的是,最上面的两个分支负责提取具体的注意力分布。例如,在位置注意力模块中,我们有一个(WH)x(WH)的注意力分布,其中*(i, j)元素表示位置i对位置j*的影响有多大。在通道块中,我们有一个CxC注意力分布,它告诉我们一个通道对另一个的影响有多大。在每个模块的第三个分支中,将这个特定的注意分布乘以输入的变换,得到通道或位置的注意力分布。如前一篇文章所述,在给定多尺度特征的背景下,将注意力分布乘以输入来提取输入的相关信息。然后对这两个模块的输出进行逐元素的相加,给出最终的自注意力特征。现在,让我们看看如何在全局框架中使用这两个模块的输出。

引导注意模块的2个细化步骤的框图

引导注意力为每个尺度建立一个连续的多个细化步骤(在提出的结构中有4个尺度)。输入特征图被送至位置和通道输出模块,输出单个特征图。它还通过了一个自动编码器,该编码器对输入进行重建。在每个block中,注意力图是由这两个输出相乘产生的。然后将此注意力图与之前生成的多尺度特征图相乘。因此,输出表示了我们需要关注特定的尺度的哪个部分。然后,通过将一个block的输出与多尺度的注意力图连接起来,并将其作为下一个block的输入,你就可以获得这样的引导注意力模块的序列。

两个相加的损失是必要的,以确保细化步骤工作正确:

  • 标准重建损失,以确保自动编码器正确重建输入的特征图
  • 引导损失,它试图最小化输入的两个后面的潜在表示之间的距离

之后,每个注意力特征通过卷积块来预测mask。为了得到最终的预测结果,需要对四个mask进行平均,这可以看作是不同尺度特征下模型的一种集成。

2.2. 为什么这样是有效的

由于这个结构比前一个复杂得多,所以很难理解注意力模块背后的情况。下面是我对各个块的贡献的理解。

位置注意模块试图根据输入图像的多尺度表示来指定要聚焦的特定尺度特征在哪个位置。通道注意模块通过指定各个通道需要注意多少来做同样的事情。在任何一个block中使用的具体操作是为了给予通道或位置信息一个注意力分布,分配哪些地方是更重要的。结合这两个模块,我们得到了一个对每个位置-通道对打分的注意力图,即特征图中的每个元素。

autoencoder用来确保feature map的后续的表示在每一步之间都没有完全改变。由于潜空间是低维的,因此只提取关键信息。我们不希望将此信息从一个细化步骤更改为下一个细化步骤,我们只希望进行较小的调整。这些在潜在表示中不会被看到。

使用一系列的引导注意力模块,可以使最终的注意力图得到细化,并逐步使噪音消失,给予真正重要的区域更多的权重。

将几个这样的多尺度网络集成起来,可以使网络同时具有全局和局部特征。然后将这些特征组合成多尺度特征图。将注意力与每个特定的尺度一起应用到多尺度特征图上,可以更好地理解哪些特征对最终的输出更有价值。

2.3. 简短的实现

class PositionAttentionModule(nn.Module):
    def __init__(self, in_channels):
        super(PositionAttentionModule, self).__init__()
        self.first_branch_conv = nn.Conv2d(in_channels, int(in_channels/8), kernel_size = 1)
        self.second_branch_conv = nn.Conv2d(in_channels, int(in_channels/8), kernel_size = 1)
        self.third_branch_conv = nn.Conv2d(in_channels, in_channels, kernel_size = 1)
        self.output_conv = nn.Conv2d(in_channels, in_channels, kernel_size = 1)
    
    def forward(self, F):
        # first branch
        F1 = self.first_branch_conv(F)                  # (C/8, W, H)
        F1 = F1.reshape((F1.size(0), F1.size(1), -1))   # (C/8, W*H)
        F1 = torch.transpose(F1, -2-1)                # (W*H, C/8)
        # second branch
        F2 = self.second_branch_conv(F)                 # (C/8, W, H)
        F2 = F2.reshape((F2.size(0), F2.size(1), -1))   # (C/8, W*H)
        F2 = nn.Softmax(dim = -1)(torch.matmul(F1, F2)) # (W*H, W*H)
        # third branch
        F3 = self.third_branch_conv(F)                  # (C, W, H)
        F3 = F3.reshape((F3.size(0), F3.size(1), -1))   # (C, W*H)
        F3 = torch.matmul(F3, F2)                       # (C, W*H)
        F3 = F3.reshape(F.shape)                        # (C, W, H)
        return self.output_conv(F3*F)

class ChannelAttentionModule(nn.Module):
    def __init__(self, in_channels):
        super(ChannelAttentionModule, self).__init__()
        self.output_conv = nn.Conv2d(in_channels, in_channels, kernel_size = 1)
    
    def forward(self, F):
        # first branch
        F1 = F.reshape((F.size(0), F.size(1), -1))      # (C, W*H)
        F1 = torch.transpose(F1, -2-1)                # (W*H, C)
        # second branch
        F2 = F.reshape((F.size(0), F.size(1), -1))      # (C, W*H)
        F2 = nn.Softmax(dim = -1)(torch.matmul(F2, F1)) # (C, C)
        # third branch
        F3 = F.reshape((F.size(0), F.size(1), -1))      # (C, W*H)
        F3 = torch.matmul(F2, F3)                       # (C, W*H)
        F3 = F3.reshape(F.shape)                        # (C, W, H)
        return self.output_conv(F3*F)

class GuidedAttentionModule(nn.Module):
    def __init__(self, in_channels_F, in_channels_Fms):
        super(GuidedAttentionModule, self).__init__()
        in_channels = in_channels_F + in_channels_Fms
        self.pam = PositionAttentionModule(in_channels)
        self.cam = ChannelAttentionModule(in_channels)
        self.encoder = nn.Sequential(nn.Conv2d(in_channels, 2*in_channels, kernel_size = 3),
                                     nn.BatchNorm2d(2*in_channels),
                                     nn.Conv2d(2*in_channels, 4*in_channels, kernel_size = 3),
                                     nn.BatchNorm2d(4*in_channels),
                                     nn.ReLU())
        self.decoder = nn.Sequential(nn.ConvTranspose2d(4*in_channels, 2*in_channels, kernel_size = 3),
                                     nn.BatchNorm2d(2*in_channels),
                                     nn.ConvTranspose2d(2*in_channels, in_channels, kernel_size = 3),
                                     nn.BatchNorm2d(in_channels),
                                     nn.ReLU())
        self.attention_map_conv = nn.Sequential(nn.Conv2d(in_channels, in_channels_Fms, kernel_size = 1),
                                                nn.BatchNorm2d(in_channels_Fms),
                                                nn.ReLU())
        
    def forward(self, F, F_ms):
        F = torch.cat((F, F_ms), dim = 1)         # concatenate the extracted feature map with the multi scale feature map
        F_pcam = self.pam(F) + self.cam(F)        # sum the ouputs of the position and channel attention modules
        F_latent = self.encoder(F)                # latent-space representation, used for the guided loss
        F_reconstructed = self.decoder(F_latent)  # output of the autoencoder, used for the reconstruction loss
        F_output = self.attention_map_conv(F_reconstructed * F_pcam)
        F_output = F_output * F_ms
        return F_output, F_reconstructed, F_latent
位置注意模块、通道注意模块和一个引导注意模块的简短的实现。

要点

那么,我们可以从这些文章中得到什么呢?注意力可以被看作是一种机制,它有助于基于网络的上下文指出需要关注的特征。

在UNet中,考虑到在扩展路径中提取的特征,在收缩路径中提取哪些特征是需要重点关注的。这有助于让跳跃连接更有意义,即传递相关信息,而不是每个提取的特征。在第二篇文章中,考虑到我们正在处理的当前的尺度,我们应该关注哪些多尺度特征。

这个概念可以应用到很多问题上,我认为多看几个例子有助于更好地理解注意力是如何适应不同问题的。


END

英文原文:https://towardsdatascience.com/using-attention-for-medical-image-segmentation-dd78825eaac6

请长按或扫描二维码关注本公众号

喜欢的话,请给我个好看吧


本文分享自微信公众号 - AI公园(AI_Paradise)。
如有侵权,请联系 support@oschina.cn 删除。
本文参与“OSC源创计划”,欢迎正在阅读的你也加入,一起分享。

展开阅读全文
打赏
0
0 收藏
分享
加载中
更多评论
打赏
0 评论
0 收藏
0
分享
返回顶部
顶部