论文笔记——Encoding in Style: a StyleGAN Encoder for Image-to-Image Translation

论文笔记——Encoding in Style: a StyleGAN Encoder for Image-to-Image Translation

3月 19, 2022 阅读 1950 字数 3282 评论 1 喜欢 0

创新点:
①新的编码器结构,把图像投影到W+空间(与以往的先还原图像,再编辑不同,本方法在W+空间中编辑)。
②证明了图像的W空间,可以提供控制和编辑的能力
③采用了一个预先训练好的StyleGAN来恢复图像


W+空间

18个不同的512维向量,每个StyleGAN的输入层一个。

网络结构


因为网络是把输入投射到18个W+空间,因此天然地支持多模态的学习,如上图。人脸轮廓通过低维度(1-7层)的W+空间层输入,而高维度(8-18层)则接受样本数据的随机输入。

参考StyleGAN的Style mixing,把获得的18个特征向量传入StyleGAN的不同层,代表了从粗糙到精细的不同特征。


效果

除了超分,这个网络还可以拓展到轮廓生成人脸,遮挡的人脸还原,甚至是猫脸狗脸。


CODE

主结构

    def forward(self, x, resize=True, latent_mask=None, input_code=False, randomize_noise=True,
                inject_latent=None, return_latents=False, alpha=None):
        if input_code:  # 是否输入的是W+空间向量
            codes = x
        else:
            codes = self.encoder(x)  # 对输入图像进行编码
            # normalize with respect to the center of an average face
            # 加上均值来归一化
            if self.opts.start_from_latent_avg:
                if self.opts.learn_in_w:
                    codes = codes + self.latent_avg.repeat(codes.shape[0], 1)
                else:
                    codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)

        # 是否加入遮挡
        if latent_mask is not None:
            for i in latent_mask:
                if inject_latent is not None:
                    if alpha is not None:
                        codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i]
                    else:
                        codes[:, i] = inject_latent[:, i]
                else:
                    codes[:, i] = 0

        input_is_latent = not input_code
        # 解码器(也是GAN的生成器)
        images, result_latent = self.decoder([codes],
                                             input_is_latent=input_is_latent,
                                             randomize_noise=randomize_noise,
                                             return_latents=return_latents)

        if resize:
            images = self.face_pool(images)

        if return_latents:
            return images, result_latent
        else:
            return images

编码器结构

    def forward(self, x):
        x = self.input_layer(x)

        latents = []
        modulelist = list(self.body._modules.values())
        # 分别对应粗糙、中等、精细特征的提取
        # 左边三个蓝色块的处理,一直卷积到深层,保留中间的三个值做残差
        for i, l in enumerate(modulelist):
            x = l(x)
            if i == 6:
                c1 = x
            elif i == 20:
                c2 = x
            elif i == 23:
                c3 = x
        # c3中间值分别进入各自的残差卷积块(提取三种不同的小特征)
        for j in range(self.coarse_ind):
            latents.append(self.styles[j](c3))

        # 如图的skip connnect,c3加卷积后的c2
        p2 = self._upsample_add(c3, self.latlayer1(c2))
        # p2中间值分别进入各自的残差卷积块(提取四种不同的中等特征)
        for j in range(self.coarse_ind, self.middle_ind):
            latents.append(self.styles[j](p2))
        # 如图的skip connnect,p2加卷积后的c1
        p1 = self._upsample_add(p2, self.latlayer2(c1))
        # p1中间值分别进入各自的残差卷积块(提取12种的的大特征)
        for j in range(self.middle_ind, self.style_count):
            latents.append(self.styles[j](p1))
        # 把特征合并
        out = torch.stack(latents, dim=1)
        return out

生成器结构(来自StyleGAN)

    def forward(...):
        # 如果输入不是特征空间,则再加入styles个fc层让输入进入特征空间(对应StyleGAN从z到w空间)
        if not input_is_latent:
            styles = [self.style(s) for s in styles]
        # 是否加入噪声(控制输出的人脸细节,如头发卷曲度)
        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers
            else:
                noise = [
                    getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)
                ]
        # 是否聚拢数据,truncation是聚拢比例,truncation_latent是均值
        if truncation < 1:
            style_t = []

            for style in styles:
                style_t.append(
                    truncation_latent + truncation * (style - truncation_latent)
                )

            styles = style_t

        if len(styles) < 2:
            inject_index = self.n_latent

            if styles[0].ndim < 3:
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:
                latent = styles[0]

        else:
            if inject_index is None:
                inject_index = random.randint(1, self.n_latent - 1)

            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)

            latent = torch.cat([latent, latent2], 1)

        # 参数化
        out = self.input(latent)
        # 第一个SytleGAN层
        out = self.conv1(out, latent[:, 0], noise=noise[0])
        # 还原为RGB
        skip = self.to_rgb1(out, latent[:, 1])

        i = 1
        # 后续的SytleGAN层
        for conv1, conv2, noise1, noise2, to_rgb in zip(
                self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
        ):
            out = conv1(out, latent[:, i], noise=noise1)
            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)

            i += 2

        image = skip

        if return_latents:
            return image, latent
        elif return_features:
            return image, out
        else:
            return image, None

评论列表

发表评论

您的电子邮箱地址不会被公开。