论文笔记——Fast Vision Transformers with HiLo Attention

论文笔记——Fast Vision Transformers with HiLo Attention

7月 21, 2022 阅读 3100 字数 3340 评论 2 喜欢 1

创新点:
①高频捕捉局部精细数据,低频聚焦全局结构
②为了区分不同频率的独特性质,让attention中的不同头分为两组,分别进入进入高\低频注意力模块,高频通过局部窗口计算自注意力,而低频通过平均池化K和V来计算全局自注意力
③在GPU上的flop、速度和内存消耗优于现有的注意力机制。

现有方法存在的问题:尽管现有的自注意力机制在低分辨率图像上效果很好,但是由于二次复杂度,在高分辨率图像上速度会慢得多。

Attention结构

高频注意力(Hi-Fi):由于高频数据关注局部细节,应用全局注意力可能是冗余的,于是采用局部窗口自注意(例如2×2窗口)捕获细粒度,从而节省了显著的计算复杂度。(无滑动或者多尺度窗口)

低频注意力(Lo-Fi):全局注意力有助于捕获低频,为了减少计算复杂度,对每个窗口应用平均池化来获得\textbf{K}\in\mathbb{R}^{N/s^2\times D_h}\textbf{V}\in\mathbb{R}^{N/s^2\times D_h},其中s^2是窗口大小。

分头策略:(1−\alpha )N_h个头将用于高频,其他\alpha N_h个头将用于低频。

相比于swin或者vit,本文提出的方法每个自注意力模块都显著减少了计算量。

位置编码

在每个FFN中加入零填充的3×3的DW卷积作为相对位置编码。

总体结构

  • ConvFFN代表应用了DW卷积的FFN。
  • DTM来自于LITv1中的可变形token融合模块。

实验结果


参数量差别不大,提升了0.2个点
相比于v1,提升了1个多点

CODE

class HiLo(nn.Module):
    """
    HiLo Attention
    Link: https://arxiv.org/abs/2205.13213
    """
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., window_size=2, alpha=0.5):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
        head_dim = int(dim/num_heads)
        self.dim = dim

        # self-attention heads in Lo-Fi
        # 低频特征
        self.l_heads = int(num_heads * alpha)
        # token dimension in Lo-Fi
        self.l_dim = self.l_heads * head_dim

        # self-attention heads in Hi-Fi
        # 高频特征
        self.h_heads = num_heads - self.l_heads
        # token dimension in Hi-Fi
        self.h_dim = self.h_heads * head_dim

        # local window size. The `s` in our paper.
        self.ws = window_size

        if self.ws == 1:
            # 窗口是1的情况等于标准多头自注意力
            self.h_heads = 0
            self.h_dim = 0
            self.l_heads = num_heads
            self.l_dim = dim

        self.scale = qk_scale or head_dim ** -0.5

        # Low frequence attention (Lo-Fi)
        if self.l_heads > 0:
            # 窗口大小不等于1则根据窗口做平均池化
            if self.ws != 1:
                self.sr = nn.AvgPool2d(kernel_size=window_size, stride=window_size)
            self.l_q = nn.Linear(self.dim, self.l_dim, bias=qkv_bias)
            self.l_kv = nn.Linear(self.dim, self.l_dim * 2, bias=qkv_bias)
            self.l_proj = nn.Linear(self.l_dim, self.l_dim)

        # High frequence attention (Hi-Fi)
        if self.h_heads > 0:
            self.h_qkv = nn.Linear(self.dim, self.h_dim * 3, bias=qkv_bias)
            self.h_proj = nn.Linear(self.h_dim, self.h_dim)

    # 高频自注意力
    def hifi(self, x):
        B, H, W, C = x.shape
        h_group, w_group = H // self.ws, W // self.ws

        # 窗口数量
        total_groups = h_group * w_group

        x = x.reshape(B, h_group, self.ws, w_group, self.ws, C).transpose(2, 3)

        # 只在窗口内做自注意力
        qkv = self.h_qkv(x).reshape(B, total_groups, -1, 3, self.h_heads, self.h_dim // self.h_heads).permute(3, 0, 1, 4, 2, 5)
        q, k, v = qkv[0], qkv[1], qkv[2]  # B, hw, n_head, ws*ws, head_dim

        attn = (q @ k.transpose(-2, -1)) * self.scale  # B, hw, n_head, ws*ws, ws*ws
        attn = attn.softmax(dim=-1)
        attn = (attn @ v).transpose(2, 3).reshape(B, h_group, w_group, self.ws, self.ws, self.h_dim)
        x = attn.transpose(2, 3).reshape(B, h_group * self.ws, w_group * self.ws, self.h_dim)

        x = self.h_proj(x)
        return x

    def lofi(self, x):
        B, H, W, C = x.shape

        # 全局自注意力
        q = self.l_q(x).reshape(B, H * W, self.l_heads, self.l_dim // self.l_heads).permute(0, 2, 1, 3)

        if self.ws > 1:
            x_ = x.permute(0, 3, 1, 2)
            # 对kv进行平均池化
            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
            kv = self.l_kv(x_).reshape(B, -1, 2, self.l_heads, self.l_dim // self.l_heads).permute(2, 0, 3, 1, 4)
        else:
            kv = self.l_kv(x).reshape(B, -1, 2, self.l_heads, self.l_dim // self.l_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(B, H, W, self.l_dim)
        x = self.l_proj(x)
        return x

    def forward(self, x):
        B, N, C = x.shape
        H = W = int(N ** 0.5)

        x = x.reshape(B, H, W, C)

        if self.h_heads == 0:
            x = self.lofi(x)
            return x.reshape(B, N, C)

        if self.l_heads == 0:
            x = self.hifi(x)
            return x.reshape(B, N, C)

        hifi_out = self.hifi(x)
        lofi_out = self.lofi(x)

        x = torch.cat((hifi_out, lofi_out), dim=-1)
        x = x.reshape(B, N, C)
        return x

评论列表

  1. 是新哥的小迷弟呀说道:

    新哥!

发表评论

您的电子邮箱地址不会被公开。