Key idea: combine the CNN's strength at capturing local features with the Transformer's strength at capturing long-range dependencies.
Panel (c) in the figure above shows the concurrent (dual-branch) structure of the whole network.
Panel (b) shows that the two branches start from the same initial features and progressively fuse features in an interactive way along both branches. At the end, the CNN branch feeds one classifier and the Transformer branch feeds another.
The CNN branch uses a feature pyramid (as depth increases, resolution decreases and the channel count increases).
The Transformer branch projects non-overlapping image patches into a vector space (which discards local detail), producing a 14×14 grid of patch tokens.
Because the CNN branch already carries local features and positional information, the Transformer needs no positional encoding.
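As a quick shape sanity check (a minimal sketch, assuming the 224×224 input and the default hyperparameters used in the code below): the ResNet-style stem downsamples by 4, and the patch-embedding convolution downsamples by another 4, which is where the 14×14 token grid comes from.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 3, 224, 224)
stem = nn.Sequential(                   # ResNet-style stem: 224 -> 56 (1/4)
    nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
)
patch_embed = nn.Conv2d(64, 768, kernel_size=4, stride=4)  # 56 -> 14 (1/4)
tokens = patch_embed(stem(x)).flatten(2).transpose(1, 2)
print(tokens.shape)  # torch.Size([1, 196, 768]), i.e. a 14x14 patch grid
```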
Feature Coupling Unit (FCU)
The CNN features have shape C×H×W.
The Transformer features have shape (K+1)×E, where K is the number of patches, the extra 1 is the class token, and E is the embedding dimension.
- To pass CNN features to the Transformer branch, a 1×1 convolution first maps the channel count to E. Then, as in figure (a) above, average pooling and a reshape convert the result to the Transformer feature shape, which is added to the Transformer features (see the sketch below).
- From the Transformer to the CNN, a symmetric operation is applied, as in figure (b).
This process is repeated N times, as in figure (c).
The structure can be abstracted as a residual coupling between the CNN and the Transformer (or, put differently, as ensemble learning): it can be read either as a CNN with Transformer shortcuts or as a Transformer with CNN shortcuts.
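To make the coupling concrete, here is a minimal sketch of the CNN→Transformer direction (a hypothetical re-implementation for illustration; the class and argument names do not match the official repo):

```python
import torch
import torch.nn as nn

class FCUDownSketch(nn.Module):
    """CNN feature map [N, C, H, W] -> Transformer tokens [N, K+1, E]."""
    def __init__(self, cnn_channels, embed_dim, dw_stride):
        super().__init__()
        self.proj = nn.Conv2d(cnn_channels, embed_dim, kernel_size=1)   # C -> E
        self.pool = nn.AvgPool2d(kernel_size=dw_stride, stride=dw_stride)
        self.norm = nn.LayerNorm(embed_dim)
        self.act = nn.GELU()

    def forward(self, x, cls_token):
        x = self.pool(self.proj(x))               # [N, E, H/s, W/s]
        x = x.flatten(2).transpose(1, 2)          # [N, K, E]
        x = self.act(self.norm(x))
        return torch.cat([cls_token, x], dim=1)   # prepend class token: [N, K+1, E]

fcu = FCUDownSketch(cnn_channels=256, embed_dim=768, dw_stride=4)
out = fcu(torch.randn(2, 256, 56, 56), torch.zeros(2, 1, 768))
print(out.shape)  # torch.Size([2, 197, 768])
```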
Code walkthrough
The ConvTransBlock layer
```python
def forward(self, x, x_t):
    x, x2 = self.cnn_block(x)
    _, _, H, W = x2.shape

    # project the CNN features into the Transformer token shape (FCU down)
    x_st = self.squeeze_block(x2, x_t)
    # add the projected CNN features to the existing Transformer features
    x_t = self.trans_block(x_st + x_t)

    if self.num_med_block > 0:
        for m in self.med_block:
            x = m(x)

    # map the Transformer tokens back to a CNN feature map and fuse with the CNN branch (FCU up)
    x_t_r = self.expand_block(x_t, H // self.dw_stride, W // self.dw_stride)
    x = self.fusion_block(x, x_t_r, return_x_2=False)

    return x, x_t
```
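Here expand_block runs the opposite direction to the coupling sketched earlier. A minimal sketch of the idea (a hypothetical helper; the real FCUUp additionally maps E back to the CNN channel count with a 1×1 conv plus BN/ReLU):

```python
import torch
import torch.nn.functional as F

def expand_sketch(x_t, H, W):
    """Transformer tokens [B, K+1, E] -> feature map [B, E, H, W]."""
    B, _, E = x_t.shape
    tokens = x_t[:, 1:]                        # drop the class token: [B, K, E]
    side = int(tokens.shape[1] ** 0.5)         # 14 for a 14x14 patch grid
    fmap = tokens.transpose(1, 2).reshape(B, E, side, side)
    return F.interpolate(fmap, size=(H, W))    # upsample to the CNN resolution

print(expand_sketch(torch.randn(2, 197, 768), 56, 56).shape)
# torch.Size([2, 768, 56, 56])
```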
The overall Conformer structure
```python
class Conformer(nn.Module):

    def __init__(self, patch_size=16, in_chans=3, num_classes=1000, base_channel=64, channel_ratio=4, num_med_block=0,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.):

        # Transformer
        super().__init__()
        self.num_classes = num_classes  # number of classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        assert depth % 3 == 0

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))  # class token of shape 1 x 1 x embed_dim, used for the final classification
        self.trans_dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        # Classifier head
        self.trans_norm = nn.LayerNorm(embed_dim)
        self.trans_cls_head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        self.pooling = nn.AdaptiveAvgPool2d(1)
        self.conv_cls_head = nn.Linear(int(256 * channel_ratio), num_classes)  # int(256 * channel_ratio) equals the final CNN stage width

        # Stem stage: get the feature maps by conv block (copied from resnet.py), i.e. the initial CNN features
        self.conv1 = nn.Conv2d(in_chans, 64, kernel_size=7, stride=2, padding=3, bias=False)  # 1 / 2 [112, 112]
        self.bn1 = nn.BatchNorm2d(64)
        self.act1 = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # 1 / 4 [56, 56]

        # 1 stage
        stage_1_channel = int(base_channel * channel_ratio)
        trans_dw_stride = patch_size // 4
        self.conv_1 = ConvBlock(inplanes=64, outplanes=stage_1_channel, res_conv=True, stride=1)  # the 1x1 Conv-BN / 3x3 Conv-BN / 1x1 Conv-BN bottleneck from the figure
        self.trans_patch_conv = nn.Conv2d(64, embed_dim, kernel_size=trans_dw_stride, stride=trans_dw_stride, padding=0)  # project 64 channels to embed_dim with a 4x4 stride-4 conv, giving a 14x14 grid of patch tokens (input is [N, 64, 56, 56])
        self.trans_1 = Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                             qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=self.trans_dpr[0],
                             )  # Transformer block: multi-head self-attention plus an MLP

        # 2~4 stage
        init_stage = 2
        fin_stage = depth // 3 + 1
        # three ConvTransBlock layers
        for i in range(init_stage, fin_stage):
            self.add_module('conv_trans_' + str(i),
                            ConvTransBlock(
                                stage_1_channel, stage_1_channel, False, 1, dw_stride=trans_dw_stride, embed_dim=embed_dim,
                                num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                                drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=self.trans_dpr[i-1],
                                num_med_block=num_med_block
                            )
                            )

        # pyramid structure: the channel count gradually increases
        stage_2_channel = int(base_channel * channel_ratio * 2)
        # 5~8 stage
        init_stage = fin_stage  # 5
        fin_stage = fin_stage + depth // 3  # 9
        for i in range(init_stage, fin_stage):
            s = 2 if i == init_stage else 1
            in_channel = stage_1_channel if i == init_stage else stage_2_channel
            res_conv = True if i == init_stage else False
            # CNN resolution gradually decreases while channels increase; the Transformer shape stays fixed
            self.add_module('conv_trans_' + str(i),
                            ConvTransBlock(
                                in_channel, stage_2_channel, res_conv, s, dw_stride=trans_dw_stride // 2, embed_dim=embed_dim,
                                num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                                drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=self.trans_dpr[i-1],
                                num_med_block=num_med_block
                            )
                            )

        # same pattern as above
        stage_3_channel = int(base_channel * channel_ratio * 2 * 2)
        # 9~12 stage
        init_stage = fin_stage  # 9
        fin_stage = fin_stage + depth // 3  # 13
        for i in range(init_stage, fin_stage):
            s = 2 if i == init_stage else 1
            in_channel = stage_2_channel if i == init_stage else stage_3_channel
            res_conv = True if i == init_stage else False
            last_fusion = True if i == depth else False
            self.add_module('conv_trans_' + str(i),
                            ConvTransBlock(
                                in_channel, stage_3_channel, res_conv, s, dw_stride=trans_dw_stride // 4, embed_dim=embed_dim,
                                num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                                drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=self.trans_dpr[i-1],
                                num_med_block=num_med_block, last_fusion=last_fusion
                            )
                            )
        self.fin_stage = fin_stage

        trunc_normal_(self.cls_token, std=.02)

        self.apply(self._init_weights)

    ...  # (remaining helper methods, e.g. _init_weights, elided)
```
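To make the pyramid concrete (a sketch based on the defaults above: patch_size=16, base_channel=64, channel_ratio=4): the CNN channel count doubles at each stage while dw_stride halves, so every stage still pools its feature map down to the same 14×14 token grid.

```python
patch_size, base_channel, channel_ratio = 16, 64, 4
trans_dw_stride = patch_size // 4                # 4
for s, feat_hw in enumerate([56, 28, 14]):       # CNN feature-map size per stage
    channels = int(base_channel * channel_ratio * 2 ** s)
    dw_stride = trans_dw_stride // 2 ** s
    print(channels, feat_hw, dw_stride, feat_hw // dw_stride)
# 256 56 4 14
# 512 28 2 14
# 1024 14 1 14
```

Now the forward pass: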
```python
    def forward(self, x):
        B = x.shape[0]
        cls_tokens = self.cls_token.expand(B, -1, -1)

        # stem stage [N, 3, 224, 224] -> [N, 64, 56, 56]
        x_base = self.maxpool(self.act1(self.bn1(self.conv1(x))))

        # 1 stage
        x = self.conv_1(x_base, return_x_2=False)

        x_t = self.trans_patch_conv(x_base).flatten(2).transpose(1, 2)
        x_t = torch.cat([cls_tokens, x_t], dim=1)
        x_t = self.trans_1(x_t)

        # 2 ~ final
        for i in range(2, self.fin_stage):
            x, x_t = eval('self.conv_trans_' + str(i))(x, x_t)  # getattr(self, 'conv_trans_' + str(i)) would be the tidier equivalent

        # conv classification: global average pooling, then a linear head
        x_p = self.pooling(x).flatten(1)
        conv_cls = self.conv_cls_head(x_p)

        # trans classification: LayerNorm, then classify from the class token
        x_t = self.trans_norm(x_t)
        tran_cls = self.trans_cls_head(x_t[:, 0])

        return [conv_cls, tran_cls]
```
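The model therefore returns two logit vectors, one per branch. A hypothetical usage sketch (it assumes the full Conformer class from the official repo is available; summing the two heads' logits is one natural way to get a single prediction):

```python
import torch

model = Conformer()              # defaults as defined above
model.eval()
with torch.no_grad():
    conv_cls, tran_cls = model(torch.randn(1, 3, 224, 224))
    pred = (conv_cls + tran_cls).argmax(dim=1)   # combine the two heads
print(pred.shape)  # torch.Size([1])
```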