This article reproduces the fine-grained classification paper *See Better Before Looking Closer*, which proposes a weakly supervised data augmentation network that crops and drops image regions under the guidance of attention maps. The reproduction uses an InceptionV3 backbone and classifies on a feature matrix produced by bilinear attention pooling. Experiments on the bird, aircraft, and car datasets match the accuracy reported in the original paper, demonstrating the effectiveness of this augmentation strategy.

The paper, titled *See Better Before Looking Closer*, is a classic work on fine-grained classification. "Fine-grained" means distinguishing subcategories within one broad category, such as bird or dog breeds, or car and aircraft models. On such problems, general-purpose networks (VGG, ResNet, Inception) achieve only middling accuracy. The paper *See Better Before Looking Closer: Weakly Supervised Data Augmentation Network for Fine-Grained Visual Classification* proposes a weakly supervised data augmentation network: an attention-map-guided augmentation strategy in which not only the original image but also the augmented images are fed into the network for training, with the losses averaged. The idea is illustrated in Figure 1. The upper half shows the training-time augmentations, Attention Cropping and Attention Dropping; the lower half shows the test-time augmentation, attention-based cropping followed by resizing back to the original input size before prediction.
Note that the augmentations we usually train with, random dropping (occlusion) and random cropping, are undirected and easily introduce noise: a random crop can land entirely on background (providing no augmentation benefit) or remove almost the whole subject (hurting convergence). The authors instead generate candidate regions from attention maps and crop or drop in a targeted way, which is the masterstroke of the paper; Figure 2 contrasts random augmentation with attention-guided augmentation. Intuitively, the attention map fixates on detail parts of the subject, such as a bird's beak; when Attention Dropping removes the beak, the model is pushed to exploit other cues such as the belly or feather color. This is the essence of the paper, and the resulting accuracy gains are shown in the ablation experiments later.
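To make the two attention-guided operations concrete, here is a minimal numpy sketch (function names and the thresholding rule are my simplification of the idea, not the repo's `batch_augment`): cropping keeps the bounding box of the high-attention region, dropping zeroes it out.

```python
import numpy as np

def attention_crop_box(attn, theta=0.5, padding_ratio=0.1):
    """Sketch of Attention Cropping: bounding box of the region where
    attention exceeds theta * max, enlarged by padding_ratio."""
    h, w = attn.shape
    ys, xs = np.where(attn >= theta * attn.max())
    pad_h, pad_w = int(padding_ratio * h), int(padding_ratio * w)
    return (max(int(ys.min()) - pad_h, 0), min(int(ys.max()) + 1 + pad_h, h),
            max(int(xs.min()) - pad_w, 0), min(int(xs.max()) + 1 + pad_w, w))

def attention_drop_mask(attn, theta=0.5):
    """Sketch of Attention Dropping: binary mask that zeroes out the
    high-attention region, forcing the model to use other parts."""
    return (attn < theta * attn.max()).astype(attn.dtype)

# toy 8x8 attention map with a hot 3x3 region
attn = np.zeros((8, 8))
attn[2:5, 3:6] = 1.0
print(attention_crop_box(attn, theta=0.5, padding_ratio=0.0))  # (2, 5, 3, 6)
print(int(attention_drop_mask(attn, theta=0.5).sum()))         # 55 pixels survive
```

The actual implementation works on batched tensors and samples `theta` from a range during training, but the geometry is the same.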
The backbone of the model is InceptionV3; its mix6e layer supplies the feature map, from which attention maps are generated for data augmentation. The attention maps and the feature map are then combined by Bilinear Attention Pooling (BAP) into the final feature matrix, which is flattened and fed to a fully connected layer for classification. The training-time architecture is shown in Figure 3. The attention-guided augmentation makes the network more robust, echoing the "See Better" half of the title.
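A rough numpy sketch of what BAP computes, under my reading of the paper (this is not the repo's `bap.py`; the sign-sqrt and L2 normalization steps are a common bilinear-pooling convention that I am assuming here): each of the M attention maps spatially weights the C-channel feature map, yielding an (M, C) part-feature matrix.

```python
import numpy as np

def bilinear_attention_pool(features, attentions):
    """Sketch of Bilinear Attention Pooling: pool the (C, H, W) feature
    map under each of the M (H, W) attention maps, then flatten."""
    M, H, W = attentions.shape
    C = features.shape[0]
    f = features.reshape(C, H * W)        # (C, HW)
    a = attentions.reshape(M, H * W)      # (M, HW)
    fm = a @ f.T / (H * W)                # (M, C) part features
    # sign-sqrt + L2 normalization (assumed, as in bilinear pooling variants)
    fm = np.sign(fm) * np.sqrt(np.abs(fm) + 1e-12)
    fm = fm / (np.linalg.norm(fm) + 1e-12)
    return fm.reshape(-1)                 # flattened vector for the FC layer

feat = np.random.rand(4, 6, 6)   # C=4 channels
attn = np.random.rand(3, 6, 6)   # M=3 attention maps
print(bilinear_attention_pool(feat, attn).shape)  # (12,)
```

The classifier thus sees M·C features, one C-dimensional descriptor per attended part.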
The test-time network is largely the same, minus the dropping augmentation: at test time we want the input image to carry as much information as possible, so nothing is dropped. The network receives the original image plus the image obtained by localizing the object with the attention map, cropping, and resizing; the two predicted probabilities are averaged. This step, called Refinement, echoes the "Looking Closer" half of the title.
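The Refinement step itself is just an average of the two forward passes, as in the evaluation code later in this notebook (the helper name here is mine):

```python
import numpy as np

def refined_prediction(pred_raw, pred_crop):
    """Sketch of test-time Refinement: average the prediction on the raw
    image with the prediction on the attention-cropped, resized image."""
    return (np.asarray(pred_raw) + np.asarray(pred_crop)) / 2.0

# the crop branch reinforces the decision when both branches agree
print(refined_prediction([2.0, 0.5], [3.0, 0.1]))  # [2.5 0.3]
```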
That is the core idea of the paper. This project is a reproduction based on PaddlePaddle 2.2.2.
Paper: See Better Before Looking Closer: Weakly Supervised Data Augmentation Network for Fine-Grained Visual Classification
Reference repo: https://github.com/wvinzh/WS_DAN_PyTorch
The datasets used in the paper are canonical fine-grained benchmarks, covering birds, aircraft, cars, and dogs, each containing images of subcategories within its broad category. This project reproduces the first three and reaches the accuracy of the original paper, as shown in the table below. The datasets are bundled with this project (no separate download needed):
| Dataset | Object | Categories | Training | Testing | ACC (reproduced) | ACC (paper) |
|---|---|---|---|---|---|---|
| CUB-200-2011 | Bird | 200 | 5994 | 5794 | 89.40 | 89.4 |
| fgvc-aircraft | Aircraft | 100 | 6667 | 3333 | 94.03 | 93.0 |
| Stanford-Cars | Car | 196 | 8144 | 8041 | 94.88 | 94.5 |
| Stanford-Dogs | Dog | 120 | 12000 | 8580 | (not required) | 92.2 |
After extraction, the dataset directory under /home/aistudio/work/data is structured as follows:
Fine-grained
├── CUB_200_2011
├── images
├── images.txt
├── image_class_labels.txt
├── train_test_split.txt
├── Car
├── cars_test
├── cars_train
├── cars_test_annos_withlabels.mat
├── devkit
├── cars_train_annos.mat
├── fgvc-aircraft-2013b
├── data
├── variants.txt
├── images_variant_trainval.txt
├── images_variant_test.txt
To show the code flow and structure clearly, the main code lives in the Jupyter notebook (see Part 5); the rest of the code under /home/aistudio/work is organized as follows:
/home/aistudio/work
├── datasets                          # dataset definition and loading
├── __init__.py                       # dataset loading functions
├── aircraft_dataset.py               # aircraft dataset definition
├── bird_dataset.py                   # bird dataset definition
├── car_dataset.py                    # car dataset definition
├── models                            # model code
├── bap.py                            # BAP module
├── inception.py                      # InceptionV3 model
├── wsdan.py                          # WS-DAN model
├── InceptionV3_pretrained.pdparams   # InceptionV3 pretrained weights
├── FGVC                              # saved parameters and training logs
├── aircraft/ckpt                     # aircraft model parameters and logs
├── *.pdparams                        # model weights
├── *.log                             # training logs
├── brid/ckpt                         # bird model parameters and logs
├── *.pdparams                        # model weights
├── *.log                             # training logs
├── car/ckpt                          # car model parameters and logs
├── *.pdparams                        # model weights
├── *.log                             # training logs
├── imgs                              # images used in the Markdown docs
├── config.py                         # hyperparameter settings (editable)
├── train.py                          # model training
└── utils.py                          # utility functions
Hardware:
Framework:
Other dependencies:
!cd data/ && unzip -oq /home/aistudio/data/data138113/Fine-grained.zip
```python
import sys
sys.path.append('/home/aistudio/work')

import os
import logging

import paddle
from paddle.io import DataLoader

import config
from datasets import getDataset
from models.wsdan import WSDAN
from utils import CenterLoss, AverageMeter, TopKAccuracyMetric, batch_augment

# Choose the dataset to evaluate
config.target_dataset = 'bird'  # it can be 'car', 'bird', 'aircraft'

# logging config
logging.basicConfig(
    filename=os.path.join('/home/aistudio/work/FGVC/' + config.target_dataset + '/ckpt/', 'test.log'),
    filemode='w', format='%(asctime)s: %(levelname)s: [%(filename)s:%(lineno)d]: %(message)s',
    level=logging.INFO)
logging.info('Current Testing Model: {}'.format(config.target_dataset))

# read the dataset
train_dataset, val_dataset = getDataset(config.target_dataset, config.input_size)
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=config.workers)
val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=config.workers)

# output the dataset info
logging.info('Dataset Name:{dataset_name}, Val:[{val_num}]'.format(
    dataset_name=config.target_dataset, val_num=len(val_dataset)))
logging.info('Batch Size:[{0}], Train Batches:[{1}], Val Batches:[{2}]'.format(
    config.batch_size, len(train_loader), len(val_loader)))

# loss and metric containers
loss_container = AverageMeter(name='loss')
raw_metric = TopKAccuracyMetric(topk=(1, 5))
crop_metric = TopKAccuracyMetric(topk=(1, 5))
drop_metric = TopKAccuracyMetric(topk=(1, 5))

# network
num_classes = train_dataset.num_classes
net = WSDAN(num_classes=num_classes, num_attentions=config.num_attentions,
            net_name=config.net_name, pretrained=False)
feature_center = paddle.zeros(shape=[num_classes, config.num_attentions * net.num_features])

if config.target_dataset == 'bird':
    net_state_dict = paddle.load("work/FGVC/bird/ckpt/bird_model.pdparams")
if config.target_dataset == 'aircraft':
    net_state_dict = paddle.load("work/FGVC/aircraft/ckpt/aircraft_model.pdparams")
if config.target_dataset == 'car':
    net_state_dict = paddle.load("work/FGVC/car/ckpt/car_model.pdparams")
net.set_state_dict(net_state_dict)
net.eval()

# loss functions
cross_entropy_loss = paddle.nn.CrossEntropyLoss()
center_loss = CenterLoss()

logs = {}
for i, (X, y) in enumerate(val_loader):
    # Raw Image
    y_pred_raw, _, attention_map = net(X)
    # Object Localization and Refinement
    crop_images = batch_augment(X, attention_map, mode='crop', theta=0.1, padding_ratio=0.05)
    y_pred_crop, _, _ = net(crop_images)
    # Final prediction: average of the raw and cropped predictions
    y_pred = (y_pred_raw + y_pred_crop) / 2.
    # loss
    batch_loss = cross_entropy_loss(y_pred, y)
    epoch_loss = loss_container(batch_loss.item())
    # metrics: top-1, top-5 accuracy
    epoch_acc = raw_metric(y_pred, y)

logs['val_{}'.format(loss_container.name)] = epoch_loss
logs['val_{}'.format(raw_metric.name)] = epoch_acc
batch_info = 'Val Loss {:.4f}, Val Acc ({:.2f}, {:.2f})'.format(epoch_loss, epoch_acc[0], epoch_acc[1])
logging.info(batch_info)
print(batch_info)
```
Val Loss 0.4921, Val Acc (88.83, 97.57)
```python
import sys
sys.path.append('/home/aistudio/work')

import os
import time
import datetime
import logging

from tqdm import tqdm
import paddle
import paddle.nn.functional as F
from paddle.io import DataLoader

import config
from datasets import getDataset
from models.wsdan import WSDAN
from utils import CenterLoss, AverageMeter, TopKAccuracyMetric, batch_augment

# Create the log directory if it does not exist
if not os.path.exists(config.save_dir):
    os.makedirs(config.save_dir)

# Log file name and format
current = datetime.datetime.now()
log_name = str(config.target_dataset) + '-0' + str(current.month) + '-' + str(current.day) + \
           '-' + str(current.hour) + '_' + str(current.minute) + ".log"
logging.basicConfig(filename=os.path.join(config.save_dir, log_name), filemode='w',
                    format='%(asctime)s: %(levelname)s: [%(filename)s:%(lineno)d]: %(message)s',
                    level=logging.INFO)

# Current model info
logging.info('Current Training Model: {}'.format(config.target_dataset))

# Read the dataset
train_dataset, val_dataset = getDataset(config.target_dataset, config.input_size)
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=config.workers)
val_loader = DataLoader(val_dataset, batch_size=config.batch_size * 4, shuffle=False, num_workers=config.workers)
num_classes = train_dataset.num_classes

# Log the dataset info
logging.info('Dataset Name:{dataset_name}, Train:[{train_num}], Val:[{val_num}]'.format(
    dataset_name=config.target_dataset, train_num=len(train_dataset), val_num=len(val_dataset)))
logging.info('Batch Size:[{0}], Train Batches:[{1}], Val Batches:[{2}]'.format(
    config.batch_size, len(train_loader), len(val_loader)))

# loss and metric containers
loss_container = AverageMeter(name='loss')
raw_metric = TopKAccuracyMetric(topk=(1, 5))
crop_metric = TopKAccuracyMetric(topk=(1, 5))
drop_metric = TopKAccuracyMetric(topk=(1, 5))
logs = {}

# Load ImageNet-pretrained weights only when not resuming from a checkpoint
if config.ckpt:
    pretrained = False
else:
    pretrained = True
net = WSDAN(num_classes=num_classes, num_attentions=config.num_attentions,
            net_name=config.net_name, pretrained=pretrained)
feature_center = paddle.zeros(shape=[num_classes, config.num_attentions * net.num_features])

# Optimizer and learning-rate schedule
scheduler = paddle.optimizer.lr.StepDecay(learning_rate=config.learning_rate, step_size=2, gamma=0.9)
optimizer = paddle.optimizer.Momentum(learning_rate=scheduler, momentum=0.9,
                                      weight_decay=1e-5, parameters=net.parameters())

# Resume model and optimizer state from a checkpoint
if config.ckpt:
    net_state_dict = paddle.load(config.save_dir + config.target_dataset + "_model.pdparams")
    optim_state_dict = paddle.load(config.save_dir + config.target_dataset + "_model.pdopt")
    net.set_state_dict(net_state_dict)
    optimizer.set_state_dict(optim_state_dict)

# Loss functions
cross_entropy_loss = paddle.nn.CrossEntropyLoss()
center_loss = CenterLoss()
```
```python
if config.ckpt:
    start_epoch = config.model_num
else:
    start_epoch = 0
max_val_acc = 0  # best validation accuracy so far

# Train for config.epochs epochs
for epoch in range(start_epoch, start_epoch + config.epochs):
    logs['epoch'] = epoch + 1
    logs['lr'] = optimizer.get_lr()
    logging.info('Epoch {:03d}, lr= {:g}'.format(epoch + 1, optimizer.get_lr()))
    print("Start epoch %d ==========, lr=%f" % (epoch + 1, optimizer.get_lr()))

    pbar = tqdm(total=len(train_loader), unit=' batches')
    pbar.set_description('Epoch {}/{}'.format(epoch + 1, config.epochs))

    # Reset metrics
    loss_container.reset()
    raw_metric.reset()
    crop_metric.reset()
    drop_metric.reset()

    # Training
    start_time = time.time()
    net.train()
    scheduler.step()
    for i, (X, y) in enumerate(train_loader):
        optimizer.clear_grad()
        y_pred_raw, feature_matrix, attention_map = net(X)

        # Update Feature Center
        feature_center_batch = F.normalize(feature_center[y], axis=-1)
        feature_center[y] += config.beta * (feature_matrix.detach() - feature_center_batch)

        # Attention Cropping
        with paddle.no_grad():
            crop_images = batch_augment(X, attention_map[:, :1, :, :], mode='crop',
                                        theta=(0.4, 0.6), padding_ratio=0.1)
        # crop images forward
        y_pred_crop, _, _ = net(crop_images)

        # Attention Dropping
        with paddle.no_grad():
            drop_images = batch_augment(X, attention_map[:, 1:, :, :], mode='drop',
                                        theta=(0.2, 0.5))
        # drop images forward
        y_pred_drop, _, _ = net(drop_images)

        # loss: raw / crop / drop branches averaged, plus center loss
        batch_loss = cross_entropy_loss(y_pred_raw, y) / 3. + \
                     cross_entropy_loss(y_pred_crop, y) / 3. + \
                     cross_entropy_loss(y_pred_drop, y) / 3. + \
                     center_loss(feature_matrix, feature_center_batch)

        # backward
        batch_loss.backward()
        optimizer.step()

        with paddle.no_grad():
            epoch_loss = loss_container(batch_loss.item())
            epoch_raw_acc = raw_metric(y_pred_raw, y)
            epoch_crop_acc = crop_metric(y_pred_crop, y)
            epoch_drop_acc = drop_metric(y_pred_drop, y)

        batch_info = 'Loss {:.4f}, Raw Acc ({:.2f}, {:.2f}), Crop Acc ({:.2f}, {:.2f}), Drop Acc ({:.2f}, {:.2f})'.format(
            epoch_loss, epoch_raw_acc[0], epoch_raw_acc[1],
            epoch_crop_acc[0], epoch_crop_acc[1], epoch_drop_acc[0], epoch_drop_acc[1])
        pbar.update()
        pbar.set_postfix_str(batch_info)

    # end of this epoch
    logs['train_{}'.format(loss_container.name)] = epoch_loss
    logs['train_raw_{}'.format(raw_metric.name)] = epoch_raw_acc
    logs['train_crop_{}'.format(crop_metric.name)] = epoch_crop_acc
    logs['train_drop_{}'.format(drop_metric.name)] = epoch_drop_acc
    logs['train_info'] = batch_info
    end_time = time.time()

    # write train log for this epoch
    logging.info('Train: {}, Time {:3.2f}'.format(batch_info, end_time - start_time))

    # Validation after every epoch
    net.eval()
    loss_container.reset()
    raw_metric.reset()
    start_time = time.time()
    with paddle.no_grad():
        for i, (X, y) in enumerate(val_loader):
            # Raw Image
            y_pred_raw, _, attention_map = net(X)
            # Object Localization and Refinement
            crop_images = batch_augment(X, attention_map, mode='crop', theta=0.1, padding_ratio=0.05)
            y_pred_crop, _, _ = net(crop_images)
            # Final prediction
            y_pred = (y_pred_raw + y_pred_crop) / 2.
            # loss
            batch_loss = cross_entropy_loss(y_pred, y)
            epoch_loss = loss_container(batch_loss.item())
            # metrics: top-1, top-5 accuracy
            epoch_acc = raw_metric(y_pred, y)

    logs['val_{}'.format(loss_container.name)] = epoch_loss
    logs['val_{}'.format(raw_metric.name)] = epoch_acc
    end_time = time.time()
    batch_info = 'Val Loss {:.4f}, Val Acc ({:.2f}, {:.2f})'.format(epoch_loss, epoch_acc[0], epoch_acc[1])
    pbar.set_postfix_str('{}, {}'.format(logs['train_info'], batch_info))

    # write validation log for this epoch
    logging.info('Valid: {}, Time {:3.2f}'.format(batch_info, end_time - start_time))
    logging.info('')
    net.train()
    pbar.close()

    # Save the model with the best validation accuracy
    if epoch_acc[0] > max_val_acc:
        max_val_acc = epoch_acc[0]
        paddle.save(net.state_dict(), config.save_dir + config.target_dataset + "_model.pdparams")
        paddle.save(optimizer.state_dict(), config.save_dir + config.target_dataset + "_model.pdopt")
```
The following notes summarize a few steps from my reproduction process (for reference only).
The greatest strength of this paper is its data augmentation. On the bird dataset with an InceptionV3 baseline, accuracy is 86.4% without augmentation. Adding random cropping or dropping helps only slightly, a 0.3-0.4% gain, whereas attention-guided cropping or dropping alone adds 1% or more, and combining Attention Cropping with Attention Dropping reaches 88.4%, a 2% gain over the baseline. Every module contributes measurably.
| Data Augmentation | Accuracy(%) |
|---|---|
| Baseline | 86.4 |
| Random Cropping | 86.8 |
| Attention Cropping | 87.8 |
| Random Dropping | 86.7 |
| Attention Dropping | 87.4 |
| Attention Cropping + Attention Dropping | 88.4 |
This reproduction contest entry covers a classic fine-grained classification paper. The idea is easy to grasp, and its weakly supervised component is elegant and instructive. The main work was porting the original PyTorch code to PaddlePaddle: the PyTorch and PaddlePaddle API references can be compared side by side, and the official PyTorch 1.8 to Paddle 2.0 API mapping table published by PaddlePaddle speeds up the lookup.
In addition, the reproduction contest requires dynamic-to-static graph conversion, and some APIs do not support static graphs, so they must be replaced. For example, paddle.einsum does not support static graphs and can be rewritten with paddle.matmul and paddle.transpose.
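A numpy demonstration of this rewrite (numpy is used here so the equivalence can be checked standalone; in Paddle the reshape/matmul/transpose calls are analogous). The subscript pattern `'imjk,injk->imn'` is the BAP-style contraction of attention maps against feature maps:

```python
import numpy as np

# (B, M, H, W) attention maps x (B, C, H, W) features -> (B, M, C)
B, M, C, H, W = 2, 3, 4, 5, 5
rng = np.random.default_rng(0)
attn = rng.random((B, M, H, W))
feat = rng.random((B, C, H, W))

# einsum version (not static-graph friendly in Paddle)
out_einsum = np.einsum('imjk,injk->imn', attn, feat)

# equivalent rewrite using only reshape / matmul / transpose
a = attn.reshape(B, M, H * W)            # flatten spatial dims
f = feat.reshape(B, C, H * W)
out_matmul = a @ f.transpose(0, 2, 1)    # batched (B,M,HW) x (B,HW,C)

print(np.allclose(out_einsum, out_matmul))  # True
```

The same pattern applies to any einsum whose subscripts reduce to a batched matrix product.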
Finally, the augmentation idea of this paper transfers well to other classification problems and is worth keeping in mind.
After the training above completes, the model and its logs are saved under work/FGVC/&lt;dataset name&gt;/ckpt for inspection.
| Item | Description |
|---|---|
| Author | Victory8858 |
| Date | 2022.05 |
| Framework version | Paddle 2.2.2 |
| Application | fine-grained classification |
| Supported hardware | GPU, CPU |
| AI Studio link | https://aistudio.baidu.com/aistudio/projectdetail/3809770?shared=1 |