Paper notes on FSRNet: End-to-End Learning Face Super-Resolution with Facial Priors

March 11, 2022

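Key contributions:
① Uses geometric facial priors, namely facial landmark heatmaps and parsing maps.
② Introduces an adversarial network (FSRGAN).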
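Landmark heatmaps are commonly rendered by placing a 2D Gaussian at each landmark, one channel per landmark. These notes do not say how FSRNet builds its ground-truth heatmaps, so the following is only a minimal NumPy sketch under that common assumption; the helper name `landmark_heatmaps`, the map size and `sigma` are illustrative choices, not taken from the paper.

```python
import numpy as np

def landmark_heatmaps(landmarks, size=64, sigma=1.0):
    """Render one Gaussian heatmap per (x, y) landmark (hypothetical helper).

    landmarks: sequence of K (x, y) pixel coordinates inside [0, size).
    Returns an array of shape (K, size, size) whose peaks equal 1.0.
    """
    landmarks = np.asarray(landmarks, dtype=np.float64)
    ys, xs = np.mgrid[0:size, 0:size]          # row (y) and column (x) grids
    maps = np.empty((len(landmarks), size, size))
    for k, (x, y) in enumerate(landmarks):
        d2 = (xs - x) ** 2 + (ys - y) ** 2     # squared distance to landmark k
        maps[k] = np.exp(-d2 / (2 * sigma ** 2))
    return maps

if __name__ == "__main__":
    hm = landmark_heatmaps([(10, 20), (40, 30)], size=64, sigma=1.5)
    print(hm.shape)                                # (2, 64, 64)
    print(divmod(int(hm[0].argmax()), 64))         # (20, 10): row y, column x
```

Stacking these channels gives a tensor of the same kind as `hmaps` in the model; parsing maps would be per-pixel region masks instead of Gaussians.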

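Overall pipeline: a coarse SR network first recovers a coarse image, which is then fed into two branches: the fine SR encoder and the prior estimation network. The prior estimation network extracts image features and estimates the facial priors (landmark heatmaps and parsing maps) from them. The encoder features and the prior features are then concatenated and passed to the fine SR decoder, which produces the final SR image.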
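To see how the two branches line up before concatenation, the spatial sizes can be worked out with the standard output-size formulas for convolution and transposed convolution. The input must already be at the output resolution, because the coarse network keeps the spatial size and its output is compared with `image_hr` in the loss. This pure-Python sketch assumes a 128×128 input and that the hourglass blocks preserve spatial size; the kernel, stride and padding values are the ones in the code section.

```python
def conv_out(n, k, s=1, p=0):
    """Output size of Conv2d along one spatial dim (dilation 1)."""
    return (n + 2 * p - k) // s + 1

def deconv_out(n, k, s=1, p=0, op=0):
    """Output size of ConvTranspose2d along one spatial dim (dilation 1)."""
    return (n - 1) * s - 2 * p + k + op

def trace_shapes(h=128):
    """Return {stage: (channels, spatial size)} for a square h x h input."""
    shapes = {}
    # Coarse SR net: ReflectionPad2d(1) then 3x3 conv, stride 1, padding 0
    coarse = conv_out(h + 2, k=3)
    shapes["coarse"] = (3, coarse)
    # Fine SR encoder: 3x3 conv, stride 2, padding 1, 64 channels
    enc = conv_out(coarse, k=3, s=2, p=1)
    shapes["encoder"] = (64, enc)
    # Prior net: 7x7 conv, stride 2, padding 3; hourglasses assumed size-preserving
    prior = conv_out(coarse, k=7, s=2, p=3)
    shapes["prior"] = (128, prior)
    assert enc == prior, "branches must match spatially to be concatenated"
    shapes["concat"] = (64 + 128, enc)
    # Decoder: 3x3 conv (padding 1), then 3x3 transposed conv with
    # stride 2, padding 1, output_padding 1
    dec = deconv_out(conv_out(enc, k=3, p=1), k=3, s=2, p=1, op=1)
    shapes["decoder"] = (3, dec)
    return shapes

for name, (c, n) in trace_shapes(128).items():
    print(f"{name:8s} {c:4d} x {n} x {n}")
```

Both branches come out at 64×64, the concatenation has 64 + 128 = 192 channels (matching the first convolution of the decoder), and the transposed convolution brings the result back to 128×128.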

Network architecture

Core structure (four networks):
CoarseSRNetwork()
FineSREncoder()
PriorEstimationNetwork()
FineSRDecoder()

Code

  • CoarseSRNetwork

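    import torch
    import torch.nn as nn

    # ResBlock (a residual block with a fixed channel count) is not shown in these notes.

    class CoarseSRNetwork(nn.Module):

        def __init__(self):
            super(CoarseSRNetwork, self).__init__()
            self.conv1 = nn.Sequential(
                nn.ReflectionPad2d(1),
                nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(64),
                nn.ReLU(True),
            )
            # Build 3 independent blocks: [ResBlock(64)] * 3 would reuse one
            # module object three times, so all three would share weights.
            self.res_blocks = nn.Sequential(*[ResBlock(64) for _ in range(3)])
            self.conv2 = nn.Sequential(
                nn.ReflectionPad2d(1),
                nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=0, bias=False),
                nn.Tanh(),
            )

        def forward(self, x):
            out = self.conv1(x)         # 3 -> 64 channels, same spatial size
            out = self.res_blocks(out)
            out = self.conv2(out)       # back to a 3-channel coarse SR image
            return out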
  • FineSREncoder

    class FineSREncoder(nn.Module):

        def __init__(self):
            super(FineSREncoder, self).__init__()
            self.conv1 = nn.Sequential(
                # stride 2 halves the spatial size
                nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(64),
                nn.ReLU(True),
            )
            # 12 independent residual blocks (not one block repeated 12 times)
            self.res_blocks = nn.Sequential(*[ResBlock(64) for _ in range(12)])
            self.conv2 = nn.Sequential(
                nn.ReflectionPad2d(1),
                nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0, bias=False),
                nn.Tanh(),
            )

        def forward(self, x):
            out = self.conv1(x)
            out = self.res_blocks(out)
            out = self.conv2(out)       # 64-channel features at half resolution
            return out
  • PriorEstimationNetwork

    class PriorEstimationNetwork(nn.Module):

        def __init__(self):
            super(PriorEstimationNetwork, self).__init__()
            self.conv1 = nn.Sequential(
                # 7x7 conv with stride 2 halves the spatial size
                nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
                nn.BatchNorm2d(64),
                nn.ReLU(True),
            )
            # Residual and HourGlassBlock are not shown in these notes;
            # Residual(64, 128) takes the 64 channels up to 128.
            self.res_blocks = nn.Sequential(
                Residual(64, 128),
                ResBlock(128),
                ResBlock(128),
            )
            self.hg_blocks = nn.Sequential(
                HourGlassBlock(128, 3),  # hourglass (funnel-shaped) residual conv network
                HourGlassBlock(128, 3),
            )

        def forward(self, x):
            out = self.conv1(x)
            out = self.res_blocks(out)
            out = self.hg_blocks(out)   # 128-channel prior features
            return out
  • FineSRDecoder

    class FineSRDecoder(nn.Module):

        def __init__(self):
            super(FineSRDecoder, self).__init__()
            self.conv1 = nn.Sequential(
                # input: 64 encoder channels + 128 prior channels = 192
                nn.Conv2d(192, 64, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(64),
                nn.ReLU(True),
            )
            self.deconv1 = nn.Sequential(
                # transposed conv with stride 2 doubles the spatial size
                nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),
                nn.BatchNorm2d(64),
                nn.ReLU(True),
            )
            # 3 independent residual blocks (no weight sharing)
            self.res_blocks = nn.Sequential(*[ResBlock(64) for _ in range(3)])
            self.conv2 = nn.Sequential(
                nn.ReflectionPad2d(1),
                nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=0, bias=False),
                nn.Tanh(),
            )

        def forward(self, x):
            out = self.conv1(x)
            out = self.deconv1(out)
            out = self.res_blocks(out)
            out = self.conv2(out)       # final 3-channel SR image
            return out
  • Overall model (FSRNet)

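    class FSRNet(nn.Module):

        def __init__(self, hmaps_ch, pmaps_ch):
            ...

        def forward(self, x):
            y_c = self.csr_net(x)  # coarse SR network
            f = self.fsr_enc(y_c)  # fine SR encoder
            p = self.pre_net(y_c)  # prior estimation network

            # 1x1 convs turn the prior features into landmark heatmaps (hmaps)
            # and parsing maps (pmaps); either head may be disabled (None)
            b1 = (self.prior_conv1 is not None)
            b2 = (self.prior_conv2 is not None)
            if b1 and b2:
                hmaps = self.prior_conv1(p)
                pmaps = self.prior_conv2(p)
                prs = torch.cat((hmaps, pmaps), 1)
            elif b1:
                prs = self.prior_conv1(p)
            elif b2:
                prs = self.prior_conv2(p)
            else:
                prs = None  # no prior head; avoids an undefined name at return

            # the decoder takes the prior features p, not the 1x1-conv outputs prs
            concat = torch.cat((f, p), 1)  # concatenate along the channel dim
            out = self.fsr_dec(concat)     # fine SR decoder
            return y_c, prs, out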
  • Loss computation

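    # criterion: a pixel-wise loss such as nn.MSELoss()
    # image_pr: ground-truth heatmaps/parsing maps, stacked in the same order as prs
    loss1 = criterion(y_c, image_hr)   # coarse SR output vs. HR image
    loss2 = criterion(out, image_hr)   # final SR output vs. HR image
    loss3 = criterion(prs, image_pr)   # estimated priors vs. ground-truth priors
    loss = loss1 + loss2 + loss3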
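Writing a residual stack as `[ResBlock(64)] * n` repeats one module object, so all n positions share a single set of weights. The sketch below makes that visible by counting parameters. `ResBlock` here is a hypothetical conv-BN-ReLU-conv-BN block with an identity skip, since these notes do not show the real one.

```python
import torch
import torch.nn as nn

class ResBlock(nn.Module):
    """Hypothetical residual block: conv-BN-ReLU-conv-BN plus identity skip."""

    def __init__(self, ch):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(ch, ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch, ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(ch),
        )

    def forward(self, x):
        return x + self.body(x)

def n_params(module):
    # parameters() yields each shared tensor only once
    return sum(p.numel() for p in module.parameters())

shared = nn.Sequential(*([ResBlock(64)] * 3))                    # one block used 3 times
independent = nn.Sequential(*[ResBlock(64) for _ in range(3)])   # 3 separate blocks

print(n_params(shared), n_params(independent))             # 73984 221952
print(shared[0] is shared[2], independent[0] is independent[2])  # True False

x = torch.randn(1, 64, 16, 16)
print(independent(x).shape)                                # torch.Size([1, 64, 16, 16])
```

The shared version still runs and produces the right shapes, which is why this mistake is easy to miss; only the parameter count (or an identity check) reveals that the network is three times smaller than intended.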

Questions

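Why can the PriorEstimationNetwork, which is built only from convolutions, extract prior information?
A: Because its output is supervised directly: after the 1x1 convolutions, prs is compared with the ground-truth heatmaps and parsing maps in its own loss term (loss3), so the network is trained to produce features that encode facial geometry.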
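One way to check this answer is to backpropagate only the prior loss and see which parameters receive gradients. In this toy sketch every sub-network is replaced by a hypothetical 1×1 convolution (names mirror `csr_net`, `fsr_enc`, `pre_net`, `prior_conv1` and `fsr_dec` from `forward`; the channel counts are arbitrary), while the wiring is kept the same.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical stand-ins: 1x1 convs wired like FSRNet.forward
csr_net = nn.Conv2d(3, 3, 1)        # coarse SR network
fsr_enc = nn.Conv2d(3, 4, 1)        # fine SR encoder
pre_net = nn.Conv2d(3, 8, 1)        # prior estimation network
prior_conv1 = nn.Conv2d(8, 5, 1)    # 1x1 head producing "heatmaps"
fsr_dec = nn.Conv2d(4 + 8, 3, 1)    # fine SR decoder on concat(f, p)
modules = {"csr_net": csr_net, "fsr_enc": fsr_enc, "pre_net": pre_net,
           "prior_conv1": prior_conv1, "fsr_dec": fsr_dec}

x = torch.randn(2, 3, 8, 8)
image_pr = torch.randn(2, 5, 8, 8)  # fake ground-truth priors

y_c = csr_net(x)
f, p = fsr_enc(y_c), pre_net(y_c)
prs = prior_conv1(p)
out = fsr_dec(torch.cat((f, p), 1))

loss3 = nn.functional.mse_loss(prs, image_pr)  # prior loss only
loss3.backward()

for name, m in modules.items():
    print(name, m.weight.grad is not None)
```

Only `csr_net`, `pre_net` and `prior_conv1` receive gradients from loss3, so this term is what pushes the prior branch toward landmark and parsing information. The reconstruction loss loss2 also reaches `pre_net`, through the concatenation feeding the decoder.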
