论文笔记——[CVPR 2022]Vision Transformer with Deformable Attention

论文笔记——[CVPR 2022]Vision Transformer with Deformable Attention

7月 3, 2022 阅读 2657 字数 3948 评论 2 喜欢 0

创新点:
①以数据依赖的方式在自注意力计算中选择K和V对。
②通过一个网络学习offset坐标,采用双线性插值计算位置。

与其他网络的对比(Attention模块)


(a) 对所有patch采用相同的自注意力(计算量和内存占用巨大)
(b) 对划分的窗口采用自注意力(稀疏注意力,对更远的数据不可知,注意力窗口与数据无关)
(c) 可变形卷积,寻找需要卷积的区域(卷积区域不能扩大)
(d) 本文的方法,自己学习对应的注意力区域

可变形卷积(DCN):

假设卷积区域为3×3,先对整体特征图做3×3conv得到大小为h×w×(2×k×k)的位置偏移矩阵,任意一点可以获得2×k×k的位置偏移信息(浮点数变量通过线性插值法算出具体的值),对偏移后的值做普通卷积操作即可。

Attention结构


(a) 均匀放置参考点,数量为group,通过一个Offset网络学习参考点的偏移,偏移后的坐标有两个作用:

  • 找到最近的四个坐标,双线性插值计算得到patch
  • 计算相对位置偏置作为B

(b) Offset计算网络,由一个5×5 depthwise卷积(stride=r)、一个GELU激活层和一个1×1卷积(out_channel=2, 无bias)构成。

双线性插值

双线性插值后的x通过以下公式得到:

其中是为了找到离偏移后的点距离在一个单位以内,大于一个单位的计算会变成0,而rx和ry是所有的坐标对应的值。
双线性插值可以用下图理解:

相对位置偏置

由于特征图是全尺寸H×W的,相对位置偏置矩阵也应该是全尺寸的\hat{B} \in \mathbb{R}^{{2H-1} \times {2W-1}}

网络总体结构


作者通过实验证明,前两个stage用swin,后两个stage用dat得到的结果最好。

各stage的参数:

CODE

DAttention的代码

    def forward(self, x):

        B, C, H, W = x.size()
        dtype, device = x.dtype, x.device

        # 通过1*1卷积投影得到q
        q = self.proj_q(x)
        # 把q分为g个组
        q_off = einops.rearrange(q, 'b (g c) h w -> (b g) c h w', g=self.n_groups, c=self.n_group_channels)
        # 经过offset网络
        offset = self.conv_offset(q_off) # B * g 2 Hg Wg
        Hk, Wk = offset.size(2), offset.size(3)
        # 计算特征图大小
        n_sample = Hk * Wk

        # 计算偏移量,先把偏移量通过tanh把范围归到-1~1,再除以对应的高或者宽,乘以最大偏移量offset_range_factor
        if self.offset_range_factor > 0:
            offset_range = torch.tensor([1.0 / Hk, 1.0 / Wk], device=device).reshape(1, 2, 1, 1)
            offset = offset.tanh().mul(offset_range).mul(self.offset_range_factor)

        offset = einops.rearrange(offset, 'b p h w -> b h w p')
        # 得到等距位置信息矩阵
        reference = self._get_ref_points(Hk, Wk, B, dtype, device)

        if self.no_off:
            offset = offset.fill(0.0)

        if self.offset_range_factor >= 0:
            # 位置信息矩阵加上范围为-offset_range_factor~+offset_range_factor的偏移
            pos = offset + reference
        else:
            pos = (offset + reference).tanh()

        # grid_sample提供一个input以及一个网格,然后根据grid中每个位置提供的坐标信息(input中pixel的坐标),
        # 将input中对应位置的像素值填充到grid指定的位置,得到最终的输出。
        # 位置非整数则采用线性插值法
        x_sampled = F.grid_sample(
            input=x.reshape(B * self.n_groups, self.n_group_channels, H, W), 
            grid=pos[..., (1, 0)], # y, x -> x, y
            mode='bilinear', align_corners=True) # B * g, Cg, Hg, Wg

        x_sampled = x_sampled.reshape(B, C, 1, n_sample)

        # q之前已经得到,通过采样后的x_sampled计算k和v
        q = q.reshape(B * self.n_heads, self.n_head_channels, H * W)
        k = self.proj_k(x_sampled).reshape(B * self.n_heads, self.n_head_channels, n_sample)
        v = self.proj_v(x_sampled).reshape(B * self.n_heads, self.n_head_channels, n_sample)

        # q乘k
        attn = torch.einsum('b c m, b c n -> b m n', q, k) # B * h, HW, Ns
        # 乘以缩放倍数
        attn = attn.mul(self.scale)

        # 使用偏置
        if self.use_pe:

            if self.dwc_pe:
                # 只在最后加上偏置
                residual_lepe = self.rpe_table(q.reshape(B, C, H, W)).reshape(B * self.n_heads, self.n_head_channels, H * W)
            elif self.fixed_pe:
                # 固定偏置(学习所有可能的q和k的固定位置)
                rpe_table = self.rpe_table
                attn_bias = rpe_table[None, ...].expand(B, -1, -1, -1)
                attn = attn + attn_bias.reshape(B * self.n_heads, H * W, self.n_sample)
            else:
                # 相对偏置(学习所有可能的q和k的相对位置)
                rpe_table = self.rpe_table
                rpe_bias = rpe_table[None, ...].expand(B, -1, -1, -1)

                q_grid = self._get_ref_points(H, W, B, dtype, device)

                displacement = (q_grid.reshape(B * self.n_groups, H * W, 2).unsqueeze(2) - pos.reshape(B * self.n_groups, n_sample, 2).unsqueeze(1)).mul(0.5)

                attn_bias = F.grid_sample(
                    input=rpe_bias.reshape(B * self.n_groups, self.n_group_heads, 2 * H - 1, 2 * W - 1),
                    grid=displacement[..., (1, 0)],
                    mode='bilinear', align_corners=True
                ) # B * g, h_g, HW, Ns

                attn_bias = attn_bias.reshape(B * self.n_heads, H * W, n_sample)

                # qk之后加上位置偏置
                attn = attn + attn_bias

        attn = F.softmax(attn, dim=2)
        attn = self.attn_drop(attn)

        # qk之后经过softmax再乘v
        out = torch.einsum('b m n, b c n -> b c m', attn, v)

        if self.use_pe and self.dwc_pe:
            out = out + residual_lepe
        out = out.reshape(B, C, H, W)

        # 再过一个1*1卷积
        y = self.proj_drop(self.proj_out(out))

        return y, pos.reshape(B, self.n_groups, Hk, Wk, 2), reference.reshape(B, self.n_groups, Hk, Wk, 2)

MLP层

class TransformerMLPWithConv(nn.Module):

    def __init__(self, channels, expansion, drop):

        super().__init__()

        self.dim1 = channels
        self.dim2 = channels * expansion
        self.linear1 = nn.Conv2d(self.dim1, self.dim2, 1, 1, 0)
        self.drop1 = nn.Dropout(drop, inplace=True)
        self.act = nn.GELU()
        self.linear2 = nn.Conv2d(self.dim2, self.dim1, 1, 1, 0) 
        self.drop2 = nn.Dropout(drop, inplace=True)
        self.dwc = nn.Conv2d(self.dim2, self.dim2, 3, 1, 1, groups=self.dim2)

    def forward(self, x):

        x = self.drop1(self.act(self.dwc(self.linear1(x))))
        x = self.drop2(self.linear2(x))

        return x

其中默认drop都是0

评论列表

  1. zzzzzzzz说道:

    请问有人用DAT的代码跑实验吗?我跑出的结果不理想

    1. latrawy说道:

      我没有试过,可变形卷积本身就挺难训的,或许跟你跑的任务有关?

发表评论

您的电子邮箱地址不会被公开。