RAS-pytorch代码详解以及到mindspore的迁移

RAS-pytorch代码详解以及到mindspore的迁移

3月 1, 2022 阅读 1568 字数 1261 评论 0 喜欢 0

RAS即“Reverse Attention-Based Residual Network for Salient Object Detection”,是一个图像显著性检测的算法,具有训练速度快,参数量少的优点。

github源代码:ShuhanChen/RAS-pytorch: Pytorch code for our TIP20 paper: Reverse Attention Based Residual Network for Salient Object Detection.. (github.com)

论文:Reverse Attention-Based Residual Network for Salient Object Detection

pytorch代码目录结构:

├── v1
  ├── data.py      # 数据读取
  ├── RAS.py       # RAS网络主结构
  ├── ResNet50.py  # ResNet50网络
  ├── test.py      # 网络推理
  └── train.py     # 网络训练

各个类的作用:

  • data.py文件

class SalObjDataset:

显著性检测数据集读取,对图像数据进行random_flip、random_rotate操作以及Normalize,最后裁剪成352*352的图像。

class test_dataset:

与上面类似,只保留resize和Normalize操作。

  • RAS.py文件

class MSCM:

多尺度上下文卷积模块,用于捕捉全局显著性的线索,用这个结构可以在不降低性能的情况下实现更少的参数。

class RA:

自上而下的反向注意力模块,对ResNet50的各层进行处理。利用resnet深层中具有高语义置信度但低分辨率的特性。通过擦除每个侧输出特征的预测区域来让网络补充细节。而当前区域是从更深层上采样获得的。

class RAS:

网络的整体结构,把从resnet得到的5个层进行处理。

  • ResNet50.py文件

与ResNet50结构类似,不同之处在于最后的输出改成了return5个层。

  • train.py文件

网络训练文件,可调的超参有:epoch, lr, trainsize, batchsize, decay_epoch。

loss值的计算:

首先是BCEloss(二进制交叉熵):

然后作者采用的是非常规的IOUloss。

  • test.py文件

用于生成推理结果,以及计算生成图像的FPS。

  • Evaluation文件

原版是用MATLAB写的评估函数,我把其中的F-measure和MAE改写成python版本。

代码:Saliency-Evaluation-Code: 代码修改自https://github.com/jiwei0921/Saliency-Evaluation-Toolbox#saliency-evaluation-toolbox,把其中的F-measure和MAE scores改写为python版本 (gitee.com)

迁移到华为昇腾的mindspore框架:

research/cv/ras · MindSpore/models – 码云 – 开源中国 (gitee.com)

发表评论

您的电子邮箱地址不会被公开。