0

0

基于Paddle2.0的注意力卷积网络CBAM和BAM

P粉084495128

P粉084495128

发布时间:2025-07-22 10:26:51

|

275人浏览过

|

来源于php中文网

原创

想给卷积网络添加注意力机制吗?是否已经厌倦了使用SE-NET?本项目使用Paddle2.0复现了含有注意力机制的卷积网络CBAM和BAM,并在动物分类数据集上进行了训练和验证。

☞☞☞AI 智能聊天, 问答助手, AI 智能搜索, 免费无限量使用 DeepSeek R1 模型☜☜☜

基于paddle2.0的注意力卷积网络cbam和bam - php中文网

项目背景

CBAM是2018年ECCV上的一篇论文CBAM: Convolutional Block Attention Module中提出的基于注意力机制的卷积网络模型。BAM是2018年BMVC上的一篇论文BAM: Bottleneck Attention Module中提出的基于注意力机制的网络模型。本项目即对其进行复现。

计算机视觉领域的注意力机制主要涵盖空间注意力和通道注意力两个方面。其中空间注意力用来捕获像素间的关系,而通道注意力用来捕获通道间的关系。CBAM提出了Convolutional Block Attention Module(CBAM)模块,该模块从空间注意力和通道注意力两个方面生成注意力特征图,然后将注意力特征图和输入进行相乘来调节注意力特征图的参数。本项目复现CBAM和BAM并用其来完成动物图像分类的实验。

项目简介

本项目首次使用paddle2.0复现了含有注意力机制的网络CBAM和BAM,并在动物数据集上进行了训练和验证。

动物数据集的划分是按8:2的的划分方法进行训练集与验证集划分的。

模型简介

CBAM网络的核心思想是提出了CBAM模块。该模块对输入先经过通道注意力模块,和输入相乘后再经过空间注意力模块,和输入再次相乘后得到调整参数的注意力特征图。如图1所示。

基于Paddle2.0的注意力卷积网络CBAM和BAM - php中文网

图1 CBAM模块细节示意图

BAM网络的核心思想是提出了BAM模块。BAM可以认为是并行版的CBAM。如图2所示。

基于Paddle2.0的注意力卷积网络CBAM和BAM - php中文网

图2 BAM模块细节示意图

具体实现可以fork后见代码细节。

论文原文:CBAM: Convolutional Block Attention Module

美间AI
美间AI

美间AI:让设计更简单

下载

参考代码:

PyTorch的实现

数据集介绍

本项目使用10分类的动物数据集进行训练和测试.

该十分类动物数据集,包含dog,horse,elephant,butterfly,chicken,cat,cow,sheep,spider和squirrel。每一分类的图片数量为2k-5k。

文件结构

文件名或文件夹名 功能
cbam.py CBAM模块定义文件
bam.py BAM模块定义文件
cbam_resnet.py CBAM和BAM网络定义文件
animal_dataset.py 数据集定义文件
config.py 配置文件
train_val_split.py 训练验证划分文件
train.py 模型训练
eval.py 模型验证

解压数据集

In [1]
!unzip -q data/data70196/animals.zip -d work/dataset

查看图片

In [ ]
import osimport randomfrom matplotlib import pyplot as pltfrom PIL import Image

imgs = []
paths = os.listdir('work/dataset')for path in paths:   
    img_path = os.path.join('work/dataset', path)    if os.path.isdir(img_path):
        img_paths = os.listdir(img_path)
        img = Image.open(os.path.join(img_path, random.choice(img_paths)))
        imgs.append((img, path))

f, ax = plt.subplots(3, 3, figsize=(12,12))for i, img in enumerate(imgs[:9]):
    ax[i//3, i%3].imshow(img[0])
    ax[i//3, i%3].axis('off')
    ax[i//3, i%3].set_title('label: %s' % img[1])
plt.show()

划分训练集和验证集

In [2]
!python code/train_val_split.py
finished train val split!

使用CBAM-ResNet50网络进行动物分类的训练并验证

训练

In [1]
!python code/train.py --net 'cbam_resnet'

验证

In [32]
!python code/eval.py --net 'cbam_resnet'
W0218 15:48:02.818117 23045 device_context.cc:362] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.0, Runtime API Version: 10.1
W0218 15:48:02.824904 23045 device_context.cc:372] device: 0, cuDNN Version: 7.6.
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 103/103 [==============================] - loss: 0.3478 - acc: 0.8544 - 232ms/step         
Eval samples: 3276
{'loss': [0.347824], 'acc': 0.8543956043956044}

使用BAM-ResNet50网络进行动物分类的训练并验证

训练

In [31]
!python code/train.py --net 'bam_resnet'
W0218 15:48:47.528769 23145 device_context.cc:362] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.0, Runtime API Version: 10.1
W0218 15:48:47.532490 23145 device_context.cc:372] device: 0, cuDNN Version: 7.6.
The loss value printed in the log is the current step, and the metric is the average value of previous step.
Epoch 1/50

验证

In [34]
!python code/eval.py --net 'bam_resnet'
W0218 19:49:38.340137  5185 device_context.cc:362] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.0, Runtime API Version: 10.1
W0218 19:49:38.343930  5185 device_context.cc:372] device: 0, cuDNN Version: 7.6.
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 103/103 [==============================] - loss: 0.2684 - acc: 0.8504 - 199ms/step        
Eval samples: 3276
{'loss': [0.2684111], 'acc': 0.8504273504273504}

图示训练验证过程

基于Paddle2.0的注意力卷积网络CBAM和BAM - php中文网

图3. 使用CBAM和BAM的训练验证图示

使用resnet50网络进行动物分类的训练并验证

训练

In [2]
!python code/train.py --net 'resnet'

验证

In [ ]
!python code/eval.py --net 'resnet'
W0213 21:34:50.038996 12684 device_context.cc:362] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W0213 21:34:50.043457 12684 device_context.cc:372] device: 0, cuDNN Version: 7.6.
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 103/103 [==============================] - loss: 1.4232 - acc: 0.5888 - 191ms/step        
Eval samples: 3276
{'loss': [1.4232028], 'acc': 0.5888278388278388}

图示训练验证过程

基于Paddle2.0的注意力卷积网络CBAM和BAM - php中文网

图4. 使用ResNet的训练验证图示

比较

基于Paddle2.0的注意力卷积网络CBAM和BAM - php中文网

图5. 使用CBAM、BAM和ResNet的验证比较图示

相关专题

更多
pytorch是干嘛的
pytorch是干嘛的

pytorch是一个基于python的深度学习框架,提供以下主要功能:动态图计算,提供灵活性。强大的张量操作,实现高效处理。自动微分,简化梯度计算。预构建的神经网络模块,简化模型构建。各种优化器,用于性能优化。想了解更多pytorch的相关内容,可以阅读本专题下面的文章。

427

2024.05.29

Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习
Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习

PyTorch 是一种用于构建深度学习模型的功能完备框架,是一种通常用于图像识别和语言处理等应用程序的机器学习。 使用Python 编写,因此对于大多数机器学习开发者而言,学习和使用起来相对简单。 PyTorch 的独特之处在于,它完全支持GPU,并且使用反向模式自动微分技术,因此可以动态修改计算图形。

5

2025.12.22

ip地址修改教程大全
ip地址修改教程大全

本专题整合了ip地址修改教程大全,阅读下面的文章自行寻找合适的解决教程。

86

2025.12.26

压缩文件加密教程汇总
压缩文件加密教程汇总

本专题整合了压缩文件加密教程,阅读专题下面的文章了解更多详细教程。

50

2025.12.26

wifi无ip分配
wifi无ip分配

本专题整合了wifi无ip分配相关教程,阅读专题下面的文章了解更多详细教程。

102

2025.12.26

漫蛙漫画入口网址
漫蛙漫画入口网址

本专题整合了漫蛙入口网址大全,阅读下面的文章领取更多入口。

297

2025.12.26

b站看视频入口合集
b站看视频入口合集

本专题整合了b站哔哩哔哩相关入口合集,阅读下面的文章查看更多入口。

592

2025.12.26

俄罗斯搜索引擎yandex入口汇总
俄罗斯搜索引擎yandex入口汇总

本专题整合了俄罗斯搜索引擎yandex相关入口合集,阅读下面的文章查看更多入口。

729

2025.12.26

虚拟号码教程汇总
虚拟号码教程汇总

本专题整合了虚拟号码接收验证码相关教程,阅读下面的文章了解更多详细操作。

63

2025.12.25

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
最新Python教程 从入门到精通
最新Python教程 从入门到精通

共4课时 | 0.6万人学习

Django 教程
Django 教程

共28课时 | 2.5万人学习

SciPy 教程
SciPy 教程

共10课时 | 0.9万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号