【CV中的Attention機制】BiSeNet中的FFM模組與ARM模組

  • 2020 年 2 月 21 日
  • 筆記

前言:之前介紹過一個語義分割中的注意力機制模組-scSE模組,效果很不錯。今天講的也是語義分割中使用到注意力機制的網路BiSeNet,這個網路有兩個模組,分別是FFM模組和ARM模組。其實現也很簡單,不過作者對注意力機制模組理解比較深入,提出的FFM模組進行的特徵融合方式也很新穎。

1. 簡介

語義分割需要豐富的空間資訊和相關大的感受野,目前很多語義分割方法為了達到實時推理的速度選擇犧牲空間解析度,這可能會導致比較差的模型表現。

BiSeNet(Bilateral Segmentation Network)中提出了空間路徑和上下文路徑:

  • 空間路徑用於保留語義資訊生成較高解析度的feature map(減少下取樣的次數)
  • 上下文路徑使用了快速下取樣的策略,用於獲取充足的感受野。
  • 提出了一個FFM模組,結合了注意力機制進行特徵融合。

本文主要關注的是速度和精度的權衡,對於解析度為2048×1024的輸入,BiSeNet能夠在NVIDIA Titan XP顯示卡上達到105FPS的速度,做到了實時語義分割。

2. 分析

提升語義分割速度主要有三種方法,如下圖所示:

  1. 通過resize的方式限定輸入大小,降低計算複雜度。缺點是空間細節有損失,尤其是邊界部分。
  2. 通過減少網路通道的個數來加快處理速度。缺點是會弱化空間資訊。
  3. 放棄最後階段的下取樣(如ENet)。缺點是模型感受野不足以覆蓋大物體,判別能力差。

語義分割中,U型結構也被廣泛使用,如下圖所示:

這種U型網路通過融合backbone不同層次的特徵,在U型結構中逐漸增加空間解析度,保留更多的細節特徵。不過有兩個缺點:

  1. 高解析度特徵圖計算量非常大,影響計算速度。
  2. 由於resize或者減少網路通道而丟失的空間資訊無法通過引入淺層而輕易復原。

3. 細節

下圖是BiSeNet的架構圖,從圖中可看到主要包括兩個部分:空間路徑和上下文路徑。

程式碼實現來自:https://github.com/ooooverflow/BiSeNet,其CP部分沒有使用Xception39而使用的ResNet18。

空間路徑SP

減少下取樣次數,只使用三個卷積層(stride=2)獲得1/8的特徵圖,由於它利用了較大尺度的特徵圖,所以可以編碼比較豐富的空間資訊。

class ConvBlock(torch.nn.Module):      def __init__(self,                   in_channels,                   out_channels,                   kernel_size=3,                   stride=2,                   padding=1):          super().__init__()          self.conv1 = nn.Conv2d(in_channels,                                 out_channels,                                 kernel_size=kernel_size,                                 stride=stride,                                 padding=padding,                                 bias=False)          self.bn = nn.BatchNorm2d(out_channels)          self.relu = nn.ReLU()        def forward(self, input):          x = self.conv1(input)          return self.relu(self.bn(x))      class Spatial_path(torch.nn.Module):      def __init__(self):          super().__init__()          self.convblock1 = ConvBlock(in_channels=3, out_channels=64)          self.convblock2 = ConvBlock(in_channels=64, out_channels=128)          self.convblock3 = ConvBlock(in_channels=128, out_channels=256)        def forward(self, input):          x = self.convblock1(input)          x = self.convblock2(x)          x = self.convblock3(x)          return x  

上下文路徑CP

為了增大感受野,論文提出上下文路徑,在Xception尾部添加全局平均池化層,從而提供更大的感受野。可以看出CP中進行了32倍的下取樣。(示例中CP部分使用的是ResNet18,不是論文中的xception39)

class resnet18(torch.nn.Module):      def __init__(self, pretrained=True):          super().__init__()          self.features = models.resnet18(pretrained=pretrained)          self.conv1 = self.features.conv1          self.bn1 = self.features.bn1          self.relu = self.features.relu          self.maxpool1 = self.features.maxpool          self.layer1 = self.features.layer1          self.layer2 = self.features.layer2          self.layer3 = self.features.layer3          self.layer4 = self.features.layer4        def forward(self, input):          x = self.conv1(input)          x = self.relu(self.bn1(x))          x = self.maxpool1(x)          feature1 = self.layer1(x)  # 1 / 4          feature2 = self.layer2(feature1)  # 1 / 8          feature3 = self.layer3(feature2)  # 1 / 16          feature4 = self.layer4(feature3)  # 1 / 32          # global average pooling to build tail          tail = torch.mean(feature4, 3, keepdim=True)          tail = torch.mean(tail, 2, keepdim=True)          return feature3, feature4, tail  

組件融合

為了SP和CP更好的融合,提出了特徵融合模組FFM還有注意力優化模組ARM。

ARM:

ARM使用在上下文路徑中,用於優化每一階段的特徵,使用全局平均池化指導特徵學習,計算成本可以忽略。其具體實現方式與SE模組很類似,屬於通道注意力機制。

class AttentionRefinementModule(torch.nn.Module):      def __init__(self, in_channels, out_channels):          super().__init__()          self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)          self.bn = nn.BatchNorm2d(out_channels)          self.sigmoid = nn.Sigmoid()          self.in_channels = in_channels          self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))        def forward(self, input):          # global average pooling          x = self.avgpool(input)          assert self.in_channels == x.size(              1), 'in_channels and out_channels should all be {}'.format(                  x.size(1))          x = self.conv(x)          # x = self.sigmoid(self.bn(x))          x = self.sigmoid(x)          # channels of input and x should be same          x = torch.mul(input, x)          return x  

FFM:

特徵融合模組用於融合CP和SP提供的輸出特徵,由於兩路特徵並不相同,所以不能對這兩部分特徵進行簡單的加權。SP提供的特徵是低層次的(8×down),CP提供的特徵是高層語義的(32×down)。

將兩個部分特徵圖通過concate方式疊加,然後使用類似SE模組的方式計算加權特徵,起到特徵選擇和結合的作用。(這種特徵融合方式值得學習)

class FeatureFusionModule(torch.nn.Module):      def __init__(self, num_classes, in_channels):          super().__init__()          self.in_channels = in_channels          self.convblock = ConvBlock(in_channels=self.in_channels,                                     out_channels=num_classes,                                     stride=1)          self.conv1 = nn.Conv2d(num_classes, num_classes, kernel_size=1)          self.relu = nn.ReLU()          self.conv2 = nn.Conv2d(num_classes, num_classes, kernel_size=1)          self.sigmoid = nn.Sigmoid()          self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))        def forward(self, input_1, input_2):          x = torch.cat((input_1, input_2), dim=1)          assert self.in_channels == x.size(              1), 'in_channels of ConvBlock should be {}'.format(x.size(1))          feature = self.convblock(x)          x = self.avgpool(feature)            x = self.relu(self.conv1(x))          x = self.sigmoid(self.conv2(x))          x = torch.mul(feature, x)          x = torch.add(x, feature)          return x  

BiSeNet網路整個模型:

class BiSeNet(torch.nn.Module):      def __init__(self, num_classes, context_path):          super().__init__()          self.spatial_path = Spatial_path()          self.context_path = build_contextpath(name=context_path)          if context_path == 'resnet101':              self.attention_refinement_module1 = AttentionRefinementModule(                  1024, 1024)              self.attention_refinement_module2 = AttentionRefinementModule(                  2048, 2048)              self.supervision1 = nn.Conv2d(in_channels=1024,                                            out_channels=num_classes,                                            kernel_size=1)              self.supervision2 = nn.Conv2d(in_channels=2048,                                            out_channels=num_classes,                                            kernel_size=1)              self.feature_fusion_module = FeatureFusionModule(num_classes, 3328)            elif context_path == 'resnet18':              self.attention_refinement_module1 = AttentionRefinementModule(                  256, 256)              self.attention_refinement_module2 = AttentionRefinementModule(                  512, 512)              self.supervision1 = nn.Conv2d(in_channels=256,                                            out_channels=num_classes,                                            kernel_size=1)              self.supervision2 = nn.Conv2d(in_channels=512,                                            out_channels=num_classes,                                            kernel_size=1)              self.feature_fusion_module = FeatureFusionModule(num_classes, 1024)          else:              print('Error: unspport context_path network n')          self.conv = nn.Conv2d(in_channels=num_classes,                                out_channels=num_classes,                                kernel_size=1)        def forward(self, input):          sx = self.spatial_path(input)          cx1, cx2, tail = self.context_path(input)          cx1 = self.attention_refinement_module1(cx1)          cx2 = self.attention_refinement_module2(cx2)          cx2 = torch.mul(cx2, tail)          cx1 = torch.nn.functional.interpolate(cx1,                                                size=sx.size()[-2:],                                                mode='bilinear')          cx2 = torch.nn.functional.interpolate(cx2,                                                size=sx.size()[-2:],                                                mode='bilinear')          cx = torch.cat((cx1, cx2), dim=1)          if self.training == True:              cx1_sup = self.supervision1(cx1)              cx2_sup = self.supervision2(cx2)              cx1_sup = torch.nn.functional.interpolate(cx1_sup,                                                        size=input.size()[-2:],                                                        mode='bilinear')              cx2_sup = torch.nn.functional.interpolate(cx2_sup,                                                        size=input.size()[-2:],                                                        mode='bilinear')          result = self.feature_fusion_module(sx, cx)          result = torch.nn.functional.interpolate(result,                                                   scale_factor=8,                                                   mode='bilinear')          result = self.conv(result)          if self.training == True:              return result, cx1_sup, cx2_sup          return result  

4. 實驗

使用了Xception39處理實時語義分割任務,在CityScapes, CamVid和COCO stuff三個數據集上進行評估。

消融實驗:

測試了basemodel xception39,參數量要比ResNet18小得多,同時MIOU只略低於與ResNet18。

以上是BiSeNet各個模組的消融實驗,可以看出,每個模組都是有效的。

統一使用了640×360解析度的圖片進行對比參數量和FLOPS狀態。

上表對BiSeNet網路和其他網路就MIOU和FPS上進行比較,可以看出該方法相比於其他方法在速度和精度方面有很大的優越性。

在使用ResNet101等比較深的網路作為backbone的情況下,效果也是超過了其他常見的網路,這證明了這個模型的有效性。

5. 結論

BiSeNet 旨在同時提升實時語義分割的速度與精度,它包含兩路網路:Spatial Path 和 Context Path。Spatial Path 被設計用來保留原影像的空間資訊,Context Path 利用輕量級模型和全局平均池化快速獲取大感受野。由此,在 105 fps 的速度下,該方法在 Cityscapes 測試集上取得了 68.4% mIoU 的結果。

歡迎關注GiantPandaCV, 在這裡你將看到獨家的深度學習分享,堅持原創,每天分享我們學習到的新鮮知識。( • ̀ω•́ )✧