Innovations:
① High frequencies capture fine local details; low frequencies focus on global structure.
② To exploit the distinct properties of the two frequency ranges, the attention heads are split into two groups that feed into a high-frequency and a low-frequency attention module respectively: the high-frequency branch computes self-attention within local windows, while the low-frequency branch computes global self-attention with average-pooled K and V.
③ Better FLOPs, speed, and memory consumption on GPUs than existing attention mechanisms.
Problem with existing methods: although existing self-attention mechanisms work well on low-resolution images, their quadratic complexity makes them much slower on high-resolution images.
Attention structure

High-frequency attention (Hi-Fi): since high-frequency information encodes local details, applying global attention to it would be largely redundant. Hi-Fi therefore uses local window self-attention (e.g., 2×2 windows) to capture fine-grained detail, saving substantial computation. (No shifted or multi-scale windows.)
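A minimal shape sketch of the non-overlapping window partition used by Hi-Fi (toy sizes chosen for illustration; it mirrors the reshape in the `hifi` method at the end of this note):

import torch

B, H, W, C, s = 1, 4, 4, 8, 2  # toy sizes; s is the window size
x = torch.randn(B, H, W, C)
# partition into non-overlapping s x s windows: (B, H/s, W/s, s, s, C)
windows = x.reshape(B, H // s, s, W // s, s, C).transpose(2, 3)
print(windows.shape)  # torch.Size([1, 2, 2, 2, 2, 8])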
Low-frequency attention (Lo-Fi): global attention helps capture low frequencies. To cut its cost, average pooling is applied to each window before projecting $\mathbf{K}\in\mathbb{R}^{N/s^2\times D_h}$ and $\mathbf{V}\in\mathbb{R}^{N/s^2\times D_h}$, where $s$ is the window side length (each window covers $s^2$ tokens).
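The shape bookkeeping for the pooled K/V tokens, sketched with assumed toy sizes (this mirrors `self.sr` in the code below):

import torch
import torch.nn as nn

B, H, W, D, s = 2, 14, 14, 64, 2  # toy sizes
x = torch.randn(B, H, W, D)
pool = nn.AvgPool2d(kernel_size=s, stride=s)
x_ = pool(x.permute(0, 3, 1, 2))           # (B, D, H/s, W/s)
kv_tokens = x_.flatten(2).transpose(1, 2)  # (B, N/s^2, D)
print(kv_tokens.shape)  # torch.Size([2, 49, 64]): 196 tokens pooled to 49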
Head-splitting strategy: $(1-\alpha)N_h$ heads are assigned to the high-frequency branch and the remaining $\alpha N_h$ heads to the low-frequency branch.
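For example, with $N_h = 8$ heads and $\alpha = 0.5$, the split is 4 Lo-Fi heads and 4 Hi-Fi heads; the integer rounding matches the code at the end of this note:

num_heads, alpha = 8, 0.5         # example values
l_heads = int(num_heads * alpha)  # 4 heads for Lo-Fi
h_heads = num_heads - l_heads     # 4 heads for Hi-Fi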
Compared with Swin or ViT, every self-attention module of the proposed method requires significantly less computation, as the following back-of-the-envelope estimate shows.
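A rough attention-only FLOP count (a sketch under simplifying assumptions, ignoring the QKV and output projections; $N$ tokens, window size $s$, branch dimensions $D_{hi}$ and $D_{lo}$):

$$\text{MSA}: \; 2N^2D \qquad \text{Hi-Fi}: \; 2Ns^2D_{hi} \qquad \text{Lo-Fi}: \; 2N\cdot\frac{N}{s^2}\cdot D_{lo}$$

Hi-Fi is linear in $N$ because each of the $N/s^2$ windows attends over only $s^2$ tokens, and Lo-Fi shrinks the quadratic term by a factor of $s^2$ because K and V are pooled down to $N/s^2$ tokens.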
Positional encoding
A zero-padded 3×3 depthwise (DW) convolution is inserted into each FFN to act as a relative positional encoding.
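A minimal sketch of this idea, assuming a standard FFN layout with the DW conv between the two linear layers (the exact ConvFFN arrangement in the paper may differ):

import torch
import torch.nn as nn

class ConvFFN(nn.Module):
    """FFN with a zero-padded 3x3 depthwise conv as implicit positional encoding."""
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        # depthwise: groups == channels; padding=1 (zero padding) keeps H, W
        self.dwconv = nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1, groups=hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        x = self.fc1(x)
        x = x.transpose(1, 2).reshape(B, -1, H, W)  # tokens -> feature map
        x = self.dwconv(x)                          # injects relative position
        x = x.flatten(2).transpose(1, 2)            # feature map -> tokens
        x = self.act(x)
        return self.fc2(x)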
Overall architecture

- ConvFFN denotes an FFN with a DW convolution applied (as sketched above).
- DTM is the deformable token merging module from LITv1.
Experimental results

With a nearly identical parameter count, accuracy improves by 0.2 points.
Compared with v1, the improvement is more than 1 point.
CODE
import torch
import torch.nn as nn


class HiLo(nn.Module):
    """
    HiLo Attention
    Link: https://arxiv.org/abs/2205.13213
    """
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None,
                 attn_drop=0., proj_drop=0., window_size=2, alpha=0.5):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
        # NOTE: attn_drop and proj_drop are accepted but unused in this implementation.
        head_dim = int(dim / num_heads)
        self.dim = dim

        # self-attention heads in Lo-Fi (low-frequency branch)
        self.l_heads = int(num_heads * alpha)
        # token dimension in Lo-Fi
        self.l_dim = self.l_heads * head_dim

        # self-attention heads in Hi-Fi (high-frequency branch)
        self.h_heads = num_heads - self.l_heads
        # token dimension in Hi-Fi
        self.h_dim = self.h_heads * head_dim

        # local window size. The `s` in the paper.
        self.ws = window_size

        if self.ws == 1:
            # a window size of 1 is equivalent to standard multi-head
            # self-attention, so route all heads to Lo-Fi
            self.h_heads = 0
            self.h_dim = 0
            self.l_heads = num_heads
            self.l_dim = dim

        self.scale = qk_scale or head_dim ** -0.5

        # Low frequency attention (Lo-Fi)
        if self.l_heads > 0:
            # for window sizes > 1, average-pool K and V window-wise
            if self.ws != 1:
                self.sr = nn.AvgPool2d(kernel_size=window_size, stride=window_size)
            self.l_q = nn.Linear(self.dim, self.l_dim, bias=qkv_bias)
            self.l_kv = nn.Linear(self.dim, self.l_dim * 2, bias=qkv_bias)
            self.l_proj = nn.Linear(self.l_dim, self.l_dim)

        # High frequency attention (Hi-Fi)
        if self.h_heads > 0:
            self.h_qkv = nn.Linear(self.dim, self.h_dim * 3, bias=qkv_bias)
            self.h_proj = nn.Linear(self.h_dim, self.h_dim)

    # high-frequency self-attention
    def hifi(self, x):
        B, H, W, C = x.shape
        h_group, w_group = H // self.ws, W // self.ws
        # number of windows
        total_groups = h_group * w_group

        x = x.reshape(B, h_group, self.ws, w_group, self.ws, C).transpose(2, 3)

        # self-attention is computed within each window only
        qkv = self.h_qkv(x).reshape(B, total_groups, -1, 3, self.h_heads, self.h_dim // self.h_heads).permute(3, 0, 1, 4, 2, 5)
        q, k, v = qkv[0], qkv[1], qkv[2]  # B, hw, n_head, ws*ws, head_dim

        attn = (q @ k.transpose(-2, -1)) * self.scale  # B, hw, n_head, ws*ws, ws*ws
        attn = attn.softmax(dim=-1)
        attn = (attn @ v).transpose(2, 3).reshape(B, h_group, w_group, self.ws, self.ws, self.h_dim)
        x = attn.transpose(2, 3).reshape(B, h_group * self.ws, w_group * self.ws, self.h_dim)

        x = self.h_proj(x)
        return x

    # low-frequency self-attention
    def lofi(self, x):
        B, H, W, C = x.shape

        # queries come from all N tokens (global attention)
        q = self.l_q(x).reshape(B, H * W, self.l_heads, self.l_dim // self.l_heads).permute(0, 2, 1, 3)

        if self.ws > 1:
            x_ = x.permute(0, 3, 1, 2)
            # average-pool the feature map before projecting K and V
            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
            kv = self.l_kv(x_).reshape(B, -1, 2, self.l_heads, self.l_dim // self.l_heads).permute(2, 0, 3, 1, 4)
        else:
            kv = self.l_kv(x).reshape(B, -1, 2, self.l_heads, self.l_dim // self.l_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B, H, W, self.l_dim)

        x = self.l_proj(x)
        return x

    def forward(self, x):
        B, N, C = x.shape
        # assumes a square token grid with side length divisible by the window size
        H = W = int(N ** 0.5)
        x = x.reshape(B, H, W, C)

        if self.h_heads == 0:
            x = self.lofi(x)
            return x.reshape(B, N, C)

        if self.l_heads == 0:
            x = self.hifi(x)
            return x.reshape(B, N, C)

        hifi_out = self.hifi(x)
        lofi_out = self.lofi(x)

        # concatenate the two branches along the channel dimension
        x = torch.cat((hifi_out, lofi_out), dim=-1)
        x = x.reshape(B, N, C)

        return x
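A quick smoke test of the module (assumed toy sizes; the token count must form a square grid whose side is divisible by the window size):

x = torch.randn(2, 196, 96)  # batch 2, 14x14 tokens, dim 96
attn = HiLo(dim=96, num_heads=8, window_size=2, alpha=0.5)
out = attn(x)
print(out.shape)  # torch.Size([2, 196, 96])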


