论文笔记——SwinIR: Image Restoration Using Swin Transformer

论文笔记——SwinIR: Image Restoration Using Swin Transformer

3月 16, 2022 阅读 1736 字数 2459 评论 0 喜欢 0

创新点:
①把Transformer引入超分
②SwinIR由浅层特征提取、深度特征提取和高质量的图像重建三部分组成。

网络结构


①浅层特征提取:33卷积层
②深层特征提取采用K个residual Swin Transformer blocks (RSTB)以及一个3
3卷积层
③最后一个残差把恢复后的HQ加上原始的LQ

RSTB模块

由连续的SwinTransformer和残差连接构成。
Swin Transformer layers (STL)
第i个RSTB的特征F计算方式为(j表示第几个Swin Transformer layers):

第i个RSTB的最后一层采用卷积,并且与当前RSTB的输入F(i,0)进行残差连接

解释

浅层特征主要包含低频特征,深层特征用于恢复丢失的高频特征。低频通过长距离skip传输到恢复模块。

损失函数

采用L1范数计算像素损失

代码

class RSTB(nn.Module):
    ...
    # 主要是在原版SwinTransformer的基础上套了一层卷积
    def forward(self, x, x_size):
        return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
class SwinIR(nn.Module):
        ...
        # 划分图像到patch
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # patch还原为图像
        self.patch_unembed = PatchUnEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        ...
    # transformer操作
    def forward_features(self, x):
        x_size = (x.shape[2], x.shape[3])
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x, x_size)

        x = self.norm(x)  # B L C
        x = self.patch_unembed(x, x_size)

        return x

    def forward(self, x):
        H, W = x.shape[2:]
        x = self.check_image_size(x)

        self.mean = self.mean.type_as(x)
        # 归一化
        x = (x - self.mean) * self.img_range

        if self.upsampler == 'pixelshuffle':
            # for classical SR
            # 浅层特征提取
            x = self.conv_first(x)
            # 深层特征提取,以及初始值的残差
            x = self.conv_after_body(self.forward_features(x)) + x
            x = self.conv_before_upsample(x)
            # 基于sub-pixel的上采样来恢复SR图像
            x = self.conv_last(self.upsample(x))
        elif self.upsampler == 'pixelshuffledirect':
            # for lightweight SR
            x = self.conv_first(x)
            x = self.conv_after_body(self.forward_features(x)) + x
            x = self.upsample(x)
        elif self.upsampler == 'nearest+conv':
            # for real-world SR
            x = self.conv_first(x)
            x = self.conv_after_body(self.forward_features(x)) + x
            x = self.conv_before_upsample(x)
            x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
            x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
            x = self.conv_last(self.lrelu(self.conv_hr(x)))
        else:
            # for image denoising and JPEG compression artifact reduction
            x_first = self.conv_first(x)
            res = self.conv_after_body(self.forward_features(x_first)) + x_first
            x = x + self.conv_last(res)

        # 反归一化
        x = x / self.img_range + self.mean

        return x[:, :, :H*self.upscale, :W*self.upscale]

发表评论

您的电子邮箱地址不会被公开。