创新点:
①把transformer引入计算机视觉
②把transformer中多头自注意力(MSA)模块替换成基于滑动窗口的模块
滑动窗口机制

红框窗口表示一个自注意力块,灰块在其中计算自注意力。假设窗口长宽为W,则红框窗口在下一层会滑动W/2个位置,然后再次在窗口内计算自注意力。由此可以让不同的窗口内的灰块与其他部分沟通。
整体架构

①把输入图像划分为不重叠的4*4大小的块(patch),每个块包含4*4*3=48个像素值,把这48个值concat(展平)起来作为这个块的token。然后让这个token线性投射到C维的特征空间。
②把这H/4*W/4个特征经过一个Swin Transformer Block。随着网络深入,为了获得分级特征,逐步让邻近的2*2个patch合并(concat,通道数变为4C),再用线性层把通道数降到2C,从而实现降采样(分辨率长宽各减半),输出维度为2C。
③Swin Transformer Block的计算:
滑动窗口带来的问题
在滑动窗口操作之后,会给原图增加很多窗口,如图一的红框,为了避免这个问题,可以把小的窗口补全。
其中的masked是为了避免原来不相邻的patch去计算相互的自注意力。(通常会补为一个很小的值,在softmax之后就变成0了)
复杂度
假设一个图像,每个窗口包含M*M个块,总共有h*w个块(C为token表示的块的维度),其复杂度为:
$\Omega(\text{MSA}) = 4hwC^2 + 2(hw)^2C$
$\Omega(\text{W-MSA}) = 4hwC^2 + 2M^2hwC$
复杂度解释:https://blog.csdn.net/weixin_43135178/article/details/120611131
相对位置偏置

https://blog.csdn.net/qq_34914551/article/details/119866975
为了用一个数字来确定某个query和key的相对位置,以此查询相应bias(可学习)
于是构造一个索引矩阵,第i个query和第j个key对应的数字就是它们的相对位置在bias表中的下标。
# Excerpt (parts elided with "...") that creates the learnable relative position
# bias table and the index matrix that maps every (query, key) pair inside a
# window to one row of that table.
# define a parameter table of relative position bias (zero-initialised here)
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
# Each arange yields the tensor [0, 1, 2, ..., self.window_size[k]-1].
# Worked example used in the comments below: self.window_size[0] and [1] are both 2,
# so the two tensors are [0, 1] and [0, 1].
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
# meshgrid produces the two matrices [[0, 0], [1, 1]] and [[0, 1], [0, 1]],
# which are then stacked into [[[0, 0], [1, 1]], [[0, 1], [0, 1]]].
# NOTE(review): meshgrid is called without an explicit `indexing` argument;
# recent PyTorch releases reportedly warn about this -- TODO confirm against
# the installed version before adding indexing="ij".
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
# flattened to [[0, 0, 1, 1], [0, 1, 0, 1]]
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
# Insert a new axis at index 2 in one copy and at index 1 in the other, then
# subtract (broadcasting) to get every pair-wise coordinate difference:
# [[[0], [0], [1], [1]], [[0], [1], [0], [1]]] minus [[[0, 0, 1, 1]], [[0, 1, 0, 1]]]
# equals
# [[[ 0, 0, -1, -1],
# [ 0, 0, -1, -1],
# [ 1, 1, 0, 0],
# [ 1, 1, 0, 0]],
# [[ 0, -1, 0, -1],
# [ 1, 0, 1, 0],
# [ 0, -1, 0, -1],
# [ 1, 0, 1, 0]]]
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
# After the permute below, entry [i][j] holds the (row, col) offset of token i
# relative to token j:
# [[[ 0, 0], [ 0, -1], [-1, 0], [-1, -1]],
# [[ 0, 1], [ 0, 0], [-1, 1], [-1, 0]],
# [[ 1, 0], [ 1, -1], [ 0, 0], [ 0, -1]],
# [[ 1, 1], [ 1, 0], [ 0, 1], [ 0, 0]]]
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
# Add an offset so that the negative differences become non-negative:
# [[[1, 1], [1, 0], [0, 1], [0, 0]],
# [[1, 2],[1, 1], [0, 2], [0, 1]],
# [[2, 1], [2, 0], [1, 1], [1, 0]],
# [[2, 2], [2, 1], [1, 2], [1, 1]]]
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
# Scale the row offset by the number of possible column offsets (2*Ww-1) so that
# each (row, col) pair maps to a distinct value once the two are summed:
# [[[3, 1], [3, 0], [0, 1], [0, 0]],
# [[3, 2], [3, 1], [0, 2], [0, 1]],
# [[6, 1], [6, 0], [3, 1], [3, 0]],
# [[6, 2], [6, 1], [3, 2], [3, 1]]]
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
# Finally sum the row and column components into a single flat index:
# [[4, 3, 1, 0],
# [5, 4, 2, 1],
# [7, 6, 4, 3],
# [8, 7, 5, 4]]
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
...
# initialise the bias table from a truncated normal distribution (std=0.02)
trunc_normal_(self.relative_position_bias_table, std=.02)
疑问
为什么Swin Transformer能减低复杂度?
答:因为只在窗口内计算自注意力
self-attention为什么要除以根号d(d是q和k的维度)?

答:1、首先要除以一个数,防止输入softmax的值过大,导致偏导数趋近于0;
2、选择根号d是因为可以使得q*k^T的结果满足期望为0,方差为1的分布,类似于归一化(假设q和k的各分量相互独立且均值为0、方差为1,则q*k^T是d项乘积之和,期望为0、方差为d,除以根号d之后方差变回1)
为什么不专门设计一个灰块用于不同窗口沟通?
代码
# Excerpt of the top-level model (parts elided with "..."): patch embedding,
# the stack of layers, final norm + pooling, and the classification head.
class SwinTransformer(nn.Module):
...
# split image into non-overlapping patches (partition the input image into
# patches and project each patch token to embed_dim channels)
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
...
# Map an input batch to one pooled feature vector per sample.
def forward_features(self, x):
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed # absolute position embedding (only when self.ape is set)
x = self.pos_drop(x) # dropout layer
for layer in self.layers:
x = layer(x) # one stage of Swin Transformer blocks -- presumably a BasicLayer; construction not shown here
x = self.norm(x) # B L C
x = self.avgpool(x.transpose(1, 2)) # B C 1
x = torch.flatten(x, 1) # flatten everything after the batch axis
return x
# Full forward pass: features followed by the classification head.
def forward(self, x):
x = self.forward_features(x)
x = self.head(x) # linear projection to class scores (presumably num_classes outputs; head definition not shown)
return x
# Excerpt of one stage (parts elided with "..."): `depth` Swin Transformer
# blocks followed by an optional downsample module.
class BasicLayer(nn.Module):
...
# build blocks
# Even-indexed blocks use shift_size=0 (regular windows); odd-indexed blocks
# use shift_size=window_size // 2 (shifted windows).
self.blocks = nn.ModuleList([
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
num_heads=num_heads, window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop, attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer)
for i in range(depth)])
...
# Run every block in order (optionally through checkpoint.checkpoint), then
# apply the downsample module if one is configured.
def forward(self, x):
for blk in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x)
else:
x = blk(x)
if self.downsample is not None:
x = self.downsample(x) # downsample: presumably halves H and W and doubles dim (as PatchMerging does) -- the assignment of self.downsample is not shown
return x
# Excerpt of a single Swin Transformer block (parts elided with "..."):
# window attention on an optionally cyclically shifted feature map, followed
# by an MLP, each wrapped in a residual connection.
class SwinTransformerBlock(nn.Module):
...
# window-based self-attention module (attention is computed inside each
# window, not across windows)
self.attn = WindowAttention(
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) # the MLP applied at the end of the block
if self.shift_size > 0: # shifted-window block
# calculate attention mask for SW-MSA (mask for the shifted-window regions)
H, W = self.input_resolution
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
# Three bands per axis: everything before the last window, the last window
# minus the shifted strip, and the shifted strip itself.
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
# Label each of the 3 x 3 regions with its own id (0..8).
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
# Pair-wise difference of region ids inside each window: zero means the two
# positions share a region, non-zero means they come from different regions.
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) # -100 is the mask value for pairs from different regions, 0 otherwise
else:
attn_mask = None
self.register_buffer("attn_mask", attn_mask) # registered as a buffer rather than a parameter, so it is not trained
# Apply the block to a token sequence x of shape (B, L, C) with L == H * W.
def forward(self, x):
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C) # reshape the token sequence back into a spatial B, H, W, C grid
# cyclic shift
if self.shift_size > 0: # shifted-window block: roll the map by shift_size along H and W
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
else:
shifted_x = x
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA (self-attention for regular or shifted windows; attn_mask is None when not shifted)
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
# merge windows (restore the spatial layout)
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
# reverse cyclic shift (undo the roll)
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
x = x.view(B, H * W, C)
# FFN
x = shortcut + self.drop_path(x) # residual connection around the attention branch
x = x + self.drop_path(self.mlp(self.norm2(x))) # residual connection around the MLP branch
return x
Patch Merging(长宽各缩小为原来的1/2,Channel为原来的两倍)
该模块的作用是在每个Stage开始前做降采样,用于缩小分辨率,调整通道数

最后加一全连接层调整通道数
class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Merges every 2x2 neighbourhood of tokens into one token: the four
    neighbours are concatenated along the channel axis (4 * dim), normalised,
    and linearly reduced to 2 * dim channels.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        # Linear reduction from the concatenated 4*dim channels down to 2*dim.
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        # Normalisation is applied on the concatenated 4*dim features.
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """Merge 2x2 neighbouring tokens.

        Args:
            x: Tensor of shape (B, H*W, C), where (H, W) is
                ``self.input_resolution``.

        Returns:
            Tensor of shape (B, H/2*W/2, 2*dim).

        Raises:
            AssertionError: If the sequence length is not H*W, or if H or W
                is odd.
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        # Pick the four members of every 2x2 neighbourhood with strided slices.
        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x






