中心差卷积

中心差卷积

3月 27, 2022 阅读 2030 字数 1416 评论 1 喜欢 1

首先描述一下普通卷积的公式:

p0代表当前输入输出的位置,pn代表R的所有枚举。
w(pn)代表pn位置的卷积核权重
,x(p0+pn)代表p0偏移了pn位置后取得值,权重和值乘积累加后得到y(p0),也就是p0处卷积后的输出。

中心差卷积


因此,中心差卷积即为采样后,所有值减去中心值,再做卷积的操作,如上图所示。

当pn=(0,0)即为中心值的时候,梯度值相对于中心位置p0本身始终等于零。

结合普通卷积和中心差卷积

将普通卷积与中心差卷积相结合可能是提供更鲁棒建模能力的可行方法。因为强度级语义信息和梯度级详细信息对于生成人脸图都是至关重要的。梯度可以让生成的图片过渡更合理。

简化公式:

CODE

########################   Centeral-difference (second order, with 9 parameters and a const theta for 3x3 kernel) 2D Convolution   ##############################
## | a1 a2 a3 |   | w1 w2 w3 |
## | a4 a5 a6 | * | w4 w5 w6 | --> output = \sum_{i=1}^{9}(ai * wi) - \sum_{i=1}^{9}wi * a5 --> Conv2d (k=3) - Conv2d (k=1)
## | a7 a8 a9 |   | w7 w8 w9 |
##
##   --> output = 
## | a1 a2 a3 |   |  w1  w2  w3 |     
## | a4 a5 a6 | * |  w4  w5  w6 |  -  | a | * | w\_sum |     (kernel_size=1x1, padding=0)
## | a7 a8 a9 |   |  w7  w8  w9 |    
class Conv2d_cd(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
                 padding=1, dilation=1, groups=1, bias=False, theta=0.7):

        super(Conv2d_cd, self).__init__() 
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.theta = theta

    def forward(self, x):
        out_normal = self.conv(x)

        if math.fabs(self.theta - 0.0) < 1e-8:
            return out_normal 
        else:
            #pdb.set_trace()
            [C_out,C_in, kernel_size,kernel_size] = self.conv.weight.shape
            kernel_diff = self.conv.weight.sum(2).sum(2)
            kernel_diff = kernel_diff[:, :, None, None]
            out_diff = F.conv2d(input=x, weight=kernel_diff, bias=self.conv.bias, stride=self.conv.stride, padding=0, groups=self.conv.groups)

            return out_normal - self.theta * out_diff

评论列表

  1. ku说道:

    shuai

发表评论

您的电子邮箱地址不会被公开。