论文笔记——[AAAI 2022]Less is More: Pay Less Attention in Vision Transformers

论文笔记——[AAAI 2022]Less is More: Pay Less Attention in Vision Transformers

7月 20, 2022 阅读 2705 字数 1706 评论 0 喜欢 3

创新点:
①在浅层用MLP编码局部特征
②在深层用自注意力捕获长距离依赖
③可变形的token融合模块,以非均匀的方式自适应地融合patch。

特点:减少计算成本

动机:

  • 先前在CNNs和Transformers上的研究表明,浅层关注局部特征,而更深层则倾向于捕获高级语义或全局关系。因此作者认为在早期不必要采用Transformers。
  • 越少的头表达能力越弱(类似1×1卷积)(具有一个头的MSA只能近似于一个FC层)
  • 采用特征金字塔的方式,在早期采用MLP,后期采用自注意力,可以避免早期高分辨率图像带来的巨大计算成本和内存占用。

整体结构


LIT的整体架构。该模型分为四个stages,我们在前两个stages应用MLP块,在后两个stages采用标准Transformer块。DTM表示所提出的可变形的token融合模块。


可变形的token融合模块Deformable Token Merging(DTM)

灵感来自可变形卷积,原始的可变形卷积可以通过公式表示为:

DC(\textbf{X}_{p,:}) = \sum_{k\in [K \times K]}\textbf{X}_{p+g(k)+\Delta g(k),:}\textbf{W}_{g(k),:,:}

与普通卷积相比可变形卷积学习了一个\Delta g(k)偏移量,这同样可以应用于特征图的生成。为了合并patches,采用可变形卷积方式:

DTM(\textbf{X}=GELU(BN(DC(\textbf(X)))))

实验结果


在参数差不多的情况下准确率相比swin提了0.几

可变形卷积token模块的有效性


绿框为原本卷积窗口,红色代表变形后的卷积窗口位置

CODE

可变形的token融合模块(调用可变形卷积的算子)

class DeformablePatchMerging(nn.Module):
    r""" Patch Merging Layer.
    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.kernel_size = 2
        self.stride = 2
        self.padding = 0
        self.c_in = dim
        self.c_out = dim*2
        self.dconv = DeformConv2dPack(dim, dim*2, kernel_size=2, stride=2, padding=0)
        self.norm_layer = nn.BatchNorm2d(dim*2)
        self.act_layer = nn.GELU()

    def forward(self, x, return_offset=False):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()
        x, offset = self.dconv(x, return_offset=False)
        x = self.act_layer(self.norm_layer(x)).flatten(2).transpose(1, 2)
        if return_offset:
            return x, offset
        else:
            return x

发表评论

您的电子邮箱地址不会被公开。