论文笔记——Swin Transformer: Hierarchical Vision Transformer using Shifted Windows

论文笔记——Swin Transformer: Hierarchical Vision Transformer using Shifted Windows

3月 16, 2022 阅读 1647 字数 7112 评论 0 喜欢 0

创新点:
①把transformer引入计算机视觉
②把transformer中多头自注意力(MSA)模块替换成基于滑动窗口的模块

滑动窗口机制


红框窗口表示一个自注意力块,灰块在其中计算自注意力。假设窗口长宽为W,则红框窗口在下一层会滑动W/2个位置,然后再次在窗口内计算自注意力。由此可以让不同的窗口内的灰块与其他部分沟通。

整体架构


①把输入图像划分为不重叠的4*4*3大小的块,于是可以把单个块concat起来作为这个块的token。然后让这个token线性投射到C维的特征空间。
②把这H/4*W/4个特征经过一个Swin Transformer Block。随着网络深入,为了获得分级特征,逐步让邻近的2*2个patch合并(concat),对其进行降采样,输出维度为2C。
③Swin Transformer Block的计算:

滑动窗口带来的问题

在滑动窗口操作之后,会给原图增加很多窗口,如图一的红框,为了避免这个问题,可以把小的窗口补全。
其中的masked是为了避免原来不相邻的patch去计算相互的自注意力。(通常会补为一个很小的值,在softmax之后就变成0了)

复杂度

假设一个图像,每个窗口包含M*M个块,总共有h*w个块(C为token表示的块的维度),复其杂度为:
复杂度解释:https://blog.csdn.net/weixin_43135178/article/details/120611131

相对位置偏置


https://blog.csdn.net/qq_34914551/article/details/119866975
为了用一个数字来确定某个query和key的相对位置,以此查询相应bias(可学习)
于是构造一个矩阵,第i个qurey和第j个key对应的数字就是他们在bias表中距离坐标。

        # define a parameter table of relative position bias初始化一个bias表
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH
        # get pair-wise relative position index for each token inside the window
        # 各生成一个tensor数组[0, 1, 2, ..., self.window_size[0]-1]
        # 假设self.window_size[0]和[1]都为2
        # 生成一个tensor数组为[0, 1]和[0, 1]
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        # 通过meshgrid形成两个矩阵[[0, 0], [1, 1]]和[[0, 1], [0, 1]]
        # 再stack为[[[0, 0], [1, 1]], [[0, 1], [0, 1]]]
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        # flatten为[[0, 0, 1, 1], [0, 1, 0, 1]]
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        # 在下标为2维插入一个维度减去在下标为1维插入一个维度得到的数组
        # [[[0], [0], [1], [1]], [[0], [1], [0], [1]]] 减去 [[[0, 0, 1, 1]], [[0, 1, 0, 1]]]
        # 等于
        # [[[ 0,  0, -1, -1],
        #  [ 0,  0, -1, -1],
        #  [ 1,  1,  0,  0],
        #  [ 1,  1,  0,  0]],

        # [[ 0, -1,  0, -1],
        #  [ 1,  0,  1,  0],
        #  [ 0, -1,  0, -1],
        #  [ 1,  0,  1,  0]]]
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        # [[[ 0,  0], [ 0, -1], [-1,  0], [-1, -1]],
        # [[ 0,  1], [ 0,  0], [-1,  1], [-1,  0]],
        # [[ 1,  0], [ 1, -1], [ 0,  0], [ 0, -1]],
        # [[ 1,  1], [ 1,  0], [ 0,  1], [ 0,  0]]]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        # 加上偏移量,把负数变成正数
        # [[[1, 1], [1, 0], [0, 1], [0, 0]],
        # [[1, 2],[1, 1], [0, 2], [0, 1]],
        # [[2, 1], [2, 0], [1, 1], [1, 0]],
        # [[2, 2], [2, 1], [1, 2], [1, 1]]]
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        # 乘以一个值进行区分扩大坐标的分布
        # [[[3, 1], [3, 0], [0, 1], [0, 0]],
        # [[3, 2], [3, 1], [0, 2], [0, 1]],
        # [[6, 1], [6, 0], [3, 1], [3, 0]],
        # [[6, 2], [6, 1], [3, 2], [3, 1]]]
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        # 最后让坐标x,y轴相加
        # [[4, 3, 1, 0],
        # [5, 4, 2, 1],
        # [7, 6, 4, 3],
        # [8, 7, 5, 4]]
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        ...
        # 用截断正态分布初始化bias表
        trunc_normal_(self.relative_position_bias_table, std=.02)

疑问

为什么Swin Transformer能减低复杂度?
答:因为只在窗口内计算自注意力

self-attention为什么要除以根号d(d是q和k的维度)?

答:1、首先要除以一个数,防止输入softmax的值过大,导致偏导数趋近于0;
2、选择根号d是因为可以使得q*k^T的结果满足期望为0,方差为1的分布,类似于归一化

为什么不专门设计一个灰块用于不同窗口沟通?

代码

class SwinTransformer(nn.Module):
        ...
        # split image into non-overlapping patches(把输入image划分为patch并把token投射到C维)
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        ...
    def forward_features(self, x):
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed  # 绝对位置编码
        x = self.pos_drop(x)  # dropout层

        for layer in self.layers:
            x = layer(x)  # SwinTransformer Block

        x = self.norm(x)  # B L C
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)  # 展平
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)  # 线性投射到num_classes个分类
        return x
class BasicLayer(nn.Module):
        ...
        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer)
            for i in range(depth)])
        ...
    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)  # 下采样让分辨率长宽各除以2,dim乘2
        return x
class SwinTransformerBlock(nn.Module):
        ...
        # 计算窗口间的自注意力
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) # Block最后的MLP层
        if self.shift_size > 0:  # 滑动窗口
            # calculate attention mask for SW-MSA(计算滑动窗口的mask区域)
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))  # 用-100作为mask的值
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)  # 注册为常量,不更新参数
    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)  # 由embeding转为图片块

        # cyclic shift
        if self.shift_size > 0:  # 滑动窗口让图片roll window_size // 2个位置
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows(划分窗口)
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA(计算普通窗口和滑动窗口的自注意力)
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

        # merge windows(恢复窗口格式)
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift(反roll操作)
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

Patch Merging(长宽各缩小两倍,Chanel为原来的两倍)

该模块的作用是在每个Stage开始前做降采样,用于缩小分辨率,调整通道数

最后加一全连接层调整通道数

class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

发表评论

您的电子邮箱地址不会被公开。