0

0

【论文复现】CSRA-Paddle: 残差注意力机制模型

P粉084495128

P粉084495128

发布时间:2025-07-29 10:13:46

|

835人浏览过

|

来源于php中文网

原创

本文介绍基于PaddlePaddle复现ICCV 2021论文的CSRA-Paddle项目。该项目通过类特定残余注意力模块(CSRA),结合类别无关平均池化特征与类特定空间注意力特征,提升多标签识别效果。在Pascal VOC 2007数据集上,Resnet101+CSRA模型复现精度达94.7 mAP,提供了完整的数据集准备、训练、验证及推理流程。

☞☞☞AI 智能聊天, 问答助手, AI 智能搜索, 免费无限量使用 DeepSeek R1 模型☜☜☜

【论文复现】csra-paddle: 残差注意力机制模型 - php中文网

CSRA-Paddle: 残差注意力机制模型

1.1 简介

本项目基于PaddlePaddle 复现了ICCV 2021 上发表的论文:
Residual Attention: A Simple But Effective Method for Multi-Label Recoginition
【论文复现】CSRA-Paddle: 残差注意力机制模型 - php中文网        

为了有效地捕捉来自不同类别的对象所占据的不同空间区域,这篇文章提出了一个非常简单的模块,称为类特定的残余注意力(CSRA)。 CSRA通过提出一个简单的空间注意力分数为每个类别生成特定于类的特征,然后将其与与类别无关的平均池化特征相结合。CSRA 在多标签识别上取得了 state-of-the-art 的结果,同时相比于其他方法简单得多。

本项目基于PaddlePaddle框架复现了CSRA,并在Pascal VOC数据集上进行了实验。

论文:

  • [1] Zhu, K. , and J. Wu . Residual Attention: A Simple But Effective Method for Multi-Label Recoginition. ICCV, 2021.

项目参考:

  • https://github.com/Kevinz-code/CSRA

上述CSRA的核心代码块:

class CSRA(nn.Layer): # one basic block 
    def __init__(self, input_dim, num_classes, T, lam):
        super(CSRA, self).__init__()
        self.T = T      # temperature       
        self.lam = lam  # Lambda                        
        self.head = nn.Conv2D(input_dim, num_classes, 1, bias_attr=False)
        self.softmax = nn.Softmax(axis=2)    def forward(self, x):
        # x (B d H W)
        # normalize classifier
        # score (B C HxW)
        score = self.head(x) / paddle.norm(self.head.weight, axis=1, keepdim=True).transpose((1, 0, 2, 3))
        score = score.flatten(2)
        base_logit = paddle.mean(score, axis=2)        if self.T == 99: # max-pooling
            att_logit = paddle.max(score, axis=2)[0]        else:
            score_soft = self.softmax(score * self.T)
            att_logit = paddle.sum(score * score_soft, axis=2)        return base_logit + self.lam * att_logit

       

可以参阅论文进行理解。

1.2 复现精度

原文在Pascal VOC 2007 val数据集的测试效果如下表

【论文复现】CSRA-Paddle: 残差注意力机制模型 - php中文网        

本项目在Pascal VOC 2007 val数据集的测试效果如下表。

Frame NetWork epochs opt lr resolution batch_size dataset card mAP
本项目Paddle Resnet101+CSRA 30 SGD 0.01 448x448 16 VOC2007 1xV100 94.7

可见,本项目成功用PaddlePaddle复现了论文结果(Resnet101+CSRA: 94.7)。

1.3 数据集

数据集网站:Pascal VOC

AiStudio上的数据集:pascal-voc

数据集介绍:

Pascal 的全称是 Pattern Analysis, Statical Modeling and Computational Learning。 PASCAL VOC 挑战赛是视觉对象的分类识别和检测的一个基准测试,提供了检测算法和学习性能的标准图像注释数据集和标准的评估系统。从2005年至今,该组织每年都会提供一系列类别的、带标签的图片,挑战者通过设计各种精妙的算法,仅根据分析图片内容来将其分类,最终通过准确率、召回率、效率来一决高下。

Pascal VOC(2005~2012)竞赛的目标主要是进行图像的目标识别,其提供的数据集包含20类的物体。每张图片都有标注,标注的物体包括人、动物(如猫、狗、岛等)、交通工具(如车、船飞机等)、家具(如椅子、桌子、沙发等)在内的20个类别。每个图像平均有2.4个目标。

VOC2007:中包含9963张标注过的图片, 由train/val/test三部分组成, 共标注出24,640个物体。

  • 本项目使用的数据集结构:
PATH/Dataset/
|-- VOCdevkit/|---- VOC2007/|------ JPEGImages/|------ Annotations/|------ ImageSets/

       

注:PATH/Dataset/为数据集的路径

DeepAI
DeepAI

为天生具有创造力的人提供的AI工具

下载

快速开始

2.1 数据准备

In [ ]
!unzip -q data/data4379/pascalvoc.zip -d data/data4379/
   
In [1]
%cd /home/aistudio/CSRA-Paddle/
!python utils/prepare/prepare_voc.py  --data_path  /home/aistudio/data/data4379/pascalvoc/VOCdevkit
       
/home/aistudio/CSRA-Paddle
generating labels for VOC07 dataset
generating final json file for VOC07 dataset
VOC07 data preparing finished!
data/voc07/trainval_voc07.json data/voc07/test_voc07.json
       

2.2 训练

In [ ]
%cd /home/aistudio/CSRA-Paddle/
!python train.py --num_heads 1 --lam 0.1 --dataset voc07 --num_cls 20 --save_dir=./checkpoint
   

2.3 验证

In [ ]
%cd /home/aistudio/CSRA-Paddle/
!python val.py --model resnet101 --num_heads 1 --lam 0.1 --dataset voc07 --num_cls 20  --load_from output/epoch_11.pdparams
   

结果:

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 310/310 [01:13

mAP: 0.946971

CP: 0.922363, CR: 0.876188, CF1 :0.898682

OP: 0.943647, OR: 0.890632, OF1 0.916373

2.4 预测

In [3]
%cd /home/aistudio/CSRA-Paddle/
!python predict.py --model resnet101 --num_heads 1 --lam 0.1 --dataset voc07 --load_from  output/epoch_11.pdparams --img_dir utils/demo_images
       
backbone params inited by paddle official model
W0410 16:12:18.782222  3012 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W0410 16:12:18.786772  3012 device_context.cc:465] device: 0, cuDNN Version: 7.6.
Loading weights from checkpoint_94.697/epoch_11.pdparams
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/math_op_patch.py:253: UserWarning: The dtype of left and right variables are not the same, left dtype is paddle.float32, but right dtype is paddle.int64, the right dtype will convert to paddle.float32
  format(lhs_dtype, rhs_dtype, lhs_dtype))
utils/demo_images/000002.jpg prediction: train,
utils/demo_images/000007.jpg prediction: car,
utils/demo_images/000004.jpg prediction: car,
utils/demo_images/000009.jpg prediction: horse,person,
utils/demo_images/000001.jpg prediction: dog,person,
utils/demo_images/000006.jpg prediction: chair,
       

2.5 TIPC

注意:本部分为论文复现赛内容,只是为了验证整个项目的训练推理的正确性。学习目的可以不进行这部分的运行,即这部分非项目必要部分。

首先安装auto_log,需要进行安装,安装方式如下: auto_log的详细介绍参考https://github.com/LDOUBLEV/AutoLog。

git clone https://github.com/LDOUBLEV/AutoLog
cd AutoLog/
pip3 install -r requirements.txt
python3 setup.py bdist_wheel
pip3 install ./dist/auto_log-1.2.0-py3-none-any.whl
       

进行TIPC:在命令行执行

bash test_tipc/prepare.sh test_tipc/configs/CSRARes101/train_infer_python.txt 'lite_train_lite_infer'bash test_tipc/test_train_inference_python.sh test_tipc/configs/CSRARes101/train_infer_python.txt 'lite_train_lite_infer'
       

注意:由于代码中每次训练需要生成数据集的标签json文件,进行tipc会覆盖原来data目录下的json文件,所以进行tipc后要进行完整训练的话。需要重新为完整数据集生成json文件,也就是重新执行数据准备的步骤

2.6 模型导出与推理

In [ ]
!python export_model.py --model resnet101 --num_heads 1 --lam 0.1 --img_size=448 --model_path=./output/epoch_11.pdparams --save_dir=./output
   
In [3]
!python infer.py --use_gpu=True --model_file=output/model.pdmodel --input_file=utils/demo_images --params_file=output/model.pdiparams
       
Inference model(CSRARes101)...
W0410 20:56:50.359391 12322 analysis_predictor.cc:795] The one-time configuration of analysis predictor failed, which may be due to native predictor called first and its configurations taken effect.--- Running analysis [ir_graph_build_pass]--- Running analysis [ir_graph_clean_pass]--- Running analysis [ir_analysis_pass]--- Running IR pass [is_test_pass]--- Running IR pass [simplify_with_basic_ops_pass]--- Running IR pass [conv_affine_channel_fuse_pass]--- Running IR pass [conv_eltwiseadd_affine_channel_fuse_pass]--- Running IR pass [conv_bn_fuse_pass]I0410 20:56:50.920820 12322 fuse_pass_base.cc:57] ---  detected 104 subgraphs--- Running IR pass [conv_eltwiseadd_bn_fuse_pass]--- Running IR pass [embedding_eltwise_layernorm_fuse_pass]--- Running IR pass [multihead_matmul_fuse_pass_v2]--- Running IR pass [squeeze2_matmul_fuse_pass]--- Running IR pass [reshape2_matmul_fuse_pass]--- Running IR pass [flatten2_matmul_fuse_pass]--- Running IR pass [map_matmul_v2_to_mul_pass]--- Running IR pass [map_matmul_v2_to_matmul_pass]--- Running IR pass [map_matmul_to_mul_pass]--- Running IR pass [fc_fuse_pass]--- Running IR pass [fc_elementwise_layernorm_fuse_pass]--- Running IR pass [conv_elementwise_add_act_fuse_pass]--- Running IR pass [conv_elementwise_add2_act_fuse_pass]--- Running IR pass [conv_elementwise_add_fuse_pass]--- Running IR pass [transpose_flatten_concat_fuse_pass]--- Running IR pass [runtime_context_cache_pass]--- Running analysis [ir_params_sync_among_devices_pass]I0410 20:56:51.119207 12322 ir_params_sync_among_devices_pass.cc:45] Sync params from CPU to GPU--- Running analysis [adjust_cudnn_workspace_size_pass]--- Running analysis [inference_op_replace_pass]--- Running analysis [memory_optimize_pass]I0410 20:56:52.790841 12322 memory_optimize_pass.cc:216] Cluster name : relu_18.tmp_0  size: 6422528
I0410 20:56:52.790884 12322 memory_optimize_pass.cc:216] Cluster name : x  size: 2408448
I0410 20:56:52.790887 12322 memory_optimize_pass.cc:216] Cluster name : tmp_2  size: 12845056
I0410 20:56:52.790899 12322 memory_optimize_pass.cc:216] Cluster name : relu_3.tmp_0  size: 12845056
I0410 20:56:52.790905 12322 memory_optimize_pass.cc:216] Cluster name : relu_9.tmp_0  size: 12845056--- Running analysis [ir_graph_to_program_pass]I0410 20:56:52.913156 12322 analysis_predictor.cc:714] ======= optimize end =======
I0410 20:56:52.924579 12322 naive_executor.cc:98] ---  skip [feed], feed -> x
I0410 20:56:52.928333 12322 naive_executor.cc:98] ---  skip [tmp_38], fetch -> fetch
W0410 20:56:52.950525 12322 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W0410 20:56:52.954545 12322 device_context.cc:465] device: 0, cuDNN Version: 7.6.
utils/demo_images/000002.jpg	prediction: 
train,
utils/demo_images/000007.jpg	prediction: 
car,
utils/demo_images/000004.jpg	prediction: 
car,
utils/demo_images/000009.jpg	prediction: 
horse,person,
utils/demo_images/000001.jpg	prediction: 
dog,person,
utils/demo_images/000006.jpg	prediction: 
chair,
       

导出的模型推理结果与动态图预测结果一致。

复现心得与相关信息

复现心得

多标签图像识别是一项具有挑战性的实用计算机视觉任务。然而,该领域的进展往往具有方法复杂、计算量大、缺乏直观解释的特点。而这篇论文则从很简单的结构设计出发,仅用几行代码,在许多不同的预训练模型和数据集上实现一致的改进,而无需任何额外的训练。CSRA 既易于实现又易于计算,还具有直观的解释。

非常值得读者在图像分类方面的进阶学习!

本次复现也是我在图像分类领域的第一次复现,同时也是第一次完成TIPC任务,学习到了TIPC的内涵,可以帮助别人更快的验证你的模型。

复现的经验分享可以从两个方面来讲:第一步是熟悉论文的核心思想和参考代码的基本结构和核心代码,对复现的难度等有一个大概的把握。第二个是快速的代码对齐。这部分主要是需要熟悉不同框架与Paddle的api函数的功能,不熟悉也没关系,可以通过查阅官网的手册和利用X2Paddle提供的对齐文档进行快速上对齐。

相关信息

信息 描述
作者 xbchen
日期 2022年4月
框架版本 PaddlePaddle==2.2.1
应用场景 图像分类
硬件支持 GPU、CPU

相关专题

更多
json数据格式
json数据格式

JSON是一种轻量级的数据交换格式。本专题为大家带来json数据格式相关文章,帮助大家解决问题。

406

2023.08.07

json是什么
json是什么

JSON是一种轻量级的数据交换格式,具有简洁、易读、跨平台和语言的特点,JSON数据是通过键值对的方式进行组织,其中键是字符串,值可以是字符串、数值、布尔值、数组、对象或者null,在Web开发、数据交换和配置文件等方面得到广泛应用。本专题为大家提供json相关的文章、下载、课程内容,供大家免费下载体验。

531

2023.08.23

jquery怎么操作json
jquery怎么操作json

操作的方法有:1、“$.parseJSON(jsonString)”2、“$.getJSON(url, data, success)”;3、“$.each(obj, callback)”;4、“$.ajax()”。更多jquery怎么操作json的详细内容,可以访问本专题下面的文章。

309

2023.10.13

go语言处理json数据方法
go语言处理json数据方法

本专题整合了go语言中处理json数据方法,阅读专题下面的文章了解更多详细内容。

74

2025.09.10

golang map内存释放
golang map内存释放

本专题整合了golang map内存相关教程,阅读专题下面的文章了解更多相关内容。

73

2025.09.05

golang map相关教程
golang map相关教程

本专题整合了golang map相关教程,阅读专题下面的文章了解更多详细内容。

28

2025.11.16

golang map原理
golang map原理

本专题整合了golang map相关内容,阅读专题下面的文章了解更多详细内容。

57

2025.11.17

java判断map相关教程
java判断map相关教程

本专题整合了java判断map相关教程,阅读专题下面的文章了解更多详细内容。

34

2025.11.27

java学习网站推荐汇总
java学习网站推荐汇总

本专题整合了java学习网站相关内容,阅读专题下面的文章了解更多详细内容。

6

2026.01.08

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
最新Python教程 从入门到精通
最新Python教程 从入门到精通

共4课时 | 0.6万人学习

Django 教程
Django 教程

共28课时 | 2.9万人学习

SciPy 教程
SciPy 教程

共10课时 | 1.1万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号