Recognizing Cattle in the Year of the Ox: Reproducing PyramidNet for Animal Classification

P粉084495128
Published: 2025-07-31 11:12:27
Original
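This project reproduces PyramidNet with Paddle 2.0 and trains and validates it on a 10-class animal dataset. The network's channel count grows block by block with depth; it widens additively and uses padding to match channel counts on the shortcut. Compared with ResNet, the layers inside its residual block are arranged differently. In the experiments, PyramidNet reaches higher accuracy than ResNet50 on animal classification with fewer parameters. The project also reproduces the paper's training and validation on CIFAR-10 and CIFAR-100.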


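Project background

PyramidNet is the network proposed in the CVPR 2017 paper Deep Pyramidal Residual Networks. Because the network appears in many later papers, this project reproduces it.

Convolutional networks are a must-have building block in computer vision. They are usually built by stacking many layers, and the channel count typically increases only where the feature map is downsampled. PyramidNet instead increases the channel count after every convolutional block, so the width grows gradually, block by block, into a pyramid shape. The paper's authors found experimentally that this design helps the network's robustness. This project takes a closer look at the pyramid network by using PyramidNet for a Year-of-the-Ox animal image classification experiment.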

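Project overview

This project reproduces PyramidNet with Paddle 2.0 for the first time and trains and validates it on an animal dataset.

The animal dataset is split into a training set and a validation set at a ratio of 8:2.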

Model overview

The core idea of PyramidNet is that the number of channels grows gradually as the network gets deeper. Figure 1 compares PyramidNet with ResNet.


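Figure 1. Structural comparison of PyramidNet and ResNet

Since the channel count grows with depth, there are two ways to grow it: additively or multiplicatively. Figure 2 illustrates both: (a) is the additive scheme, (b) is the multiplicative scheme, and (c) compares the two. Scheme (b) is closer to a conventional ResNet. The paper's experiments show that additive growth works better than multiplicative growth, so this project uses the additive scheme.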


Figure 2. Additive versus multiplicative channel growth in PyramidNet
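To make the two schedules concrete, the sketch below computes the output width of each block. It is an illustration rather than the project's code: the function names are hypothetical, rounding with `round` is an assumption, and the multiplicative schedule is parameterised here so that it ends at the same final width as the additive one.

```python
def additive_widths(base, alpha, num_blocks):
    """Widths after each block when every block adds alpha / num_blocks channels."""
    step = alpha / num_blocks
    return [int(round(base + step * k)) for k in range(1, num_blocks + 1)]


def multiplicative_widths(base, alpha, num_blocks):
    """Widths after each block when every block scales the width by a fixed ratio.

    Assumption: the ratio is chosen so that the last block reaches base + alpha,
    which makes the two schedules directly comparable.
    """
    ratio = ((base + alpha) / base) ** (1.0 / num_blocks)
    return [int(round(base * ratio ** k)) for k in range(1, num_blocks + 1)]


if __name__ == "__main__":
    print("additive:      ", additive_widths(64, 300, 12))
    print("multiplicative:", multiplicative_widths(64, 300, 12))
```

With a base width of 64, alpha=300 and 12 blocks, the additive schedule gives 89, 114, ..., 364, which are the channel counts that appear in the network summary printed later for the non-bottleneck model. The multiplicative schedule starts more slowly and concentrates its growth in the last blocks.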

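Model details

Because PyramidNet's channel count grows block by block, the residual block has to match channel counts before it adds the shortcut to the residual branch. PyramidNet does this by padding the shortcut, as shown in Figure 3: (a) shows channel matching by padding, and (b) shows an equivalent decomposition of the same operation.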


Figure 3. Channel matching in the PyramidNet residual block
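The padded shortcut can be sketched with NumPy arrays in NCHW layout. This is a minimal illustration under assumptions, not the code in PyramidNet.py: the function name `padded_shortcut_add` is hypothetical, and the extra channels are assumed to be filled with zeros.

```python
import numpy as np


def padded_shortcut_add(shortcut, residual):
    """Add a narrower shortcut to a wider residual branch by zero-padding channels.

    Both arrays are in NCHW layout and must share N, H and W.
    """
    n, c_in, h, w = shortcut.shape
    c_out = residual.shape[1]
    if c_out < c_in:
        raise ValueError("residual branch must have at least as many channels")
    # Extra channels on the shortcut are all zeros, so they leave the
    # corresponding residual channels unchanged after the addition.
    pad = np.zeros((n, c_out - c_in, h, w), dtype=shortcut.dtype)
    return residual + np.concatenate([shortcut, pad], axis=1)


if __name__ == "__main__":
    x = np.ones((1, 64, 4, 4), dtype=np.float32)         # block input, 64 channels
    r = np.full((1, 89, 4, 4), 2.0, dtype=np.float32)    # residual branch, 89 channels
    out = padded_shortcut_add(x, r)
    print(out.shape)
```

The first 64 output channels receive shortcut plus residual, while the 25 new channels carry only the residual branch. This is the decomposition that panel (b) of Figure 3 describes.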

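Another implementation detail is that the convolution, BatchNorm, and ReLU layers inside the residual block are arranged differently from ResNet. Fork the project to see the details in the code.

Paper: Deep Pyramidal Residual Networks

Reference code:

LuaTorch implementation

Caffe implementation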

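PyTorch implementation

Dataset

This project trains and validates on a 10-class animal dataset.

The ten classes are dog, horse, elephant, butterfly, chicken, cat, cow, sheep, spider, and squirrel. Each class has between 2k and 5k images.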

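File structure

File or folder          Purpose
PyramidNet.py           PyramidNet model definition
animal_dataset.py       Dataset definition
config.py               Configuration
train_val_split.py      Training/validation split
train_cifar.py          Reproduces the paper's CIFAR training
train.py                Model training
eval.py                 Model validation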

Unzip the dataset

In [ ]
!unzip -q data/data70196/animals.zip -d work/dataset

View images

In [ ]
import os
import random

from matplotlib import pyplot as plt
from PIL import Image

# Pick one random image from each class folder.
imgs = []
paths = os.listdir('work/dataset')
for path in paths:
    img_path = os.path.join('work/dataset', path)
    if os.path.isdir(img_path):
        img_paths = os.listdir(img_path)
        img = Image.open(os.path.join(img_path, random.choice(img_paths)))
        imgs.append((img, path))

# Show the first nine samples in a 3x3 grid, titled with the class name.
f, ax = plt.subplots(3, 3, figsize=(12, 12))
for i, img in enumerate(imgs[:9]):
    ax[i//3, i%3].imshow(img[0])
    ax[i//3, i%3].axis('off')
    ax[i//3, i%3].set_title('label: %s' % img[1])
plt.show()
<Figure size 864x864 with 9 Axes>

Split into training and validation sets

In [ ]
!python code/train_val_split.py
finished train val split!
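The contents of train_val_split.py are not shown here. As a rough idea of what an 8:2 split per class can look like, here is a minimal sketch using only the standard library. The function name `split_train_val`, the fixed seed, and the per-class splitting are all assumptions, not the project's actual behaviour.

```python
import random


def split_train_val(files_by_class, train_ratio=0.8, seed=0):
    """Split each class's file list into train and val parts at train_ratio."""
    rng = random.Random(seed)  # fixed seed so the split is reproducible
    train, val = [], []
    for label, files in sorted(files_by_class.items()):
        files = sorted(files)
        rng.shuffle(files)
        cut = int(len(files) * train_ratio)
        train += [(name, label) for name in files[:cut]]
        val += [(name, label) for name in files[cut:]]
    return train, val


if __name__ == "__main__":
    demo = {
        "cow": ["cow_%d.jpg" % i for i in range(10)],
        "cat": ["cat_%d.jpg" % i for i in range(5)],
    }
    train, val = split_train_val(demo)
    print(len(train), "training samples,", len(val), "validation samples")
```

Splitting inside each class keeps the class proportions of the two sets close to each other, which matters when the classes range from 2k to 5k images.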

Print the network structure

In [7]
%cd code
/home/aistudio/code
In [10]
import paddle
In [11]
from PyramidNet import PyramidNet
model_basic = PyramidNet('imagenet', 32, 300, num_classes=10, bottleneck=False)
model_basic = paddle.Model(model_basic)
model_basic.summary((-1, 3, 224, 224))
=> the layer configuration for each stage is set to [3, 3, 3, 3]
---------------------------------------------------------------------------
 Layer (type)       Input Shape          Output Shape         Param #    
===========================================================================
   Conv2D-26     [[1, 3, 224, 224]]   [1, 64, 112, 112]        9,408     
BatchNorm2D-39  [[1, 64, 112, 112]]   [1, 64, 112, 112]         256      
    ReLU-15     [[1, 64, 112, 112]]   [1, 64, 112, 112]          0       
  MaxPool2D-2   [[1, 64, 112, 112]]    [1, 64, 56, 56]           0       
BatchNorm2D-40   [[1, 64, 56, 56]]     [1, 64, 56, 56]          256      
   Conv2D-27     [[1, 64, 56, 56]]     [1, 89, 56, 56]        51,264     
BatchNorm2D-41   [[1, 89, 56, 56]]     [1, 89, 56, 56]          356      
    ReLU-16      [[1, 89, 56, 56]]     [1, 89, 56, 56]           0       
   Conv2D-28     [[1, 89, 56, 56]]     [1, 89, 56, 56]        71,289     
BatchNorm2D-42   [[1, 89, 56, 56]]     [1, 89, 56, 56]          356      
 BasicBlock-13   [[1, 64, 56, 56]]     [1, 89, 56, 56]           0       
BatchNorm2D-43   [[1, 89, 56, 56]]     [1, 89, 56, 56]          356      
   Conv2D-29     [[1, 89, 56, 56]]     [1, 114, 56, 56]       91,314     
BatchNorm2D-44   [[1, 114, 56, 56]]    [1, 114, 56, 56]         456      
    ReLU-17      [[1, 114, 56, 56]]    [1, 114, 56, 56]          0       
   Conv2D-30     [[1, 114, 56, 56]]    [1, 114, 56, 56]       116,964    
BatchNorm2D-45   [[1, 114, 56, 56]]    [1, 114, 56, 56]         456      
 BasicBlock-14   [[1, 89, 56, 56]]     [1, 114, 56, 56]          0       
BatchNorm2D-46   [[1, 114, 56, 56]]    [1, 114, 56, 56]         456      
   Conv2D-31     [[1, 114, 56, 56]]    [1, 139, 56, 56]       142,614    
BatchNorm2D-47   [[1, 139, 56, 56]]    [1, 139, 56, 56]         556      
    ReLU-18      [[1, 139, 56, 56]]    [1, 139, 56, 56]          0       
   Conv2D-32     [[1, 139, 56, 56]]    [1, 139, 56, 56]       173,889    
BatchNorm2D-48   [[1, 139, 56, 56]]    [1, 139, 56, 56]         556      
 BasicBlock-15   [[1, 114, 56, 56]]    [1, 139, 56, 56]          0       
BatchNorm2D-49   [[1, 139, 56, 56]]    [1, 139, 56, 56]         556      
   Conv2D-33     [[1, 139, 56, 56]]    [1, 164, 28, 28]       205,164    
BatchNorm2D-50   [[1, 164, 28, 28]]    [1, 164, 28, 28]         656      
    ReLU-19      [[1, 164, 28, 28]]    [1, 164, 28, 28]          0       
   Conv2D-34     [[1, 164, 28, 28]]    [1, 164, 28, 28]       242,064    
BatchNorm2D-51   [[1, 164, 28, 28]]    [1, 164, 28, 28]         656      
  AvgPool2D-5    [[1, 139, 56, 56]]    [1, 139, 28, 28]          0       
 BasicBlock-16   [[1, 139, 56, 56]]    [1, 164, 28, 28]          0       
BatchNorm2D-52   [[1, 164, 28, 28]]    [1, 164, 28, 28]         656      
   Conv2D-35     [[1, 164, 28, 28]]    [1, 189, 28, 28]       278,964    
BatchNorm2D-53   [[1, 189, 28, 28]]    [1, 189, 28, 28]         756      
    ReLU-20      [[1, 189, 28, 28]]    [1, 189, 28, 28]          0       
   Conv2D-36     [[1, 189, 28, 28]]    [1, 189, 28, 28]       321,489    
BatchNorm2D-54   [[1, 189, 28, 28]]    [1, 189, 28, 28]         756      
 BasicBlock-17   [[1, 164, 28, 28]]    [1, 189, 28, 28]          0       
BatchNorm2D-55   [[1, 189, 28, 28]]    [1, 189, 28, 28]         756      
   Conv2D-37     [[1, 189, 28, 28]]    [1, 214, 28, 28]       364,014    
BatchNorm2D-56   [[1, 214, 28, 28]]    [1, 214, 28, 28]         856      
    ReLU-21      [[1, 214, 28, 28]]    [1, 214, 28, 28]          0       
   Conv2D-38     [[1, 214, 28, 28]]    [1, 214, 28, 28]       412,164    
BatchNorm2D-57   [[1, 214, 28, 28]]    [1, 214, 28, 28]         856      
 BasicBlock-18   [[1, 189, 28, 28]]    [1, 214, 28, 28]          0       
BatchNorm2D-58   [[1, 214, 28, 28]]    [1, 214, 28, 28]         856      
   Conv2D-39     [[1, 214, 28, 28]]    [1, 239, 14, 14]       460,314    
BatchNorm2D-59   [[1, 239, 14, 14]]    [1, 239, 14, 14]         956      
    ReLU-22      [[1, 239, 14, 14]]    [1, 239, 14, 14]          0       
   Conv2D-40     [[1, 239, 14, 14]]    [1, 239, 14, 14]       514,089    
BatchNorm2D-60   [[1, 239, 14, 14]]    [1, 239, 14, 14]         956      
  AvgPool2D-6    [[1, 214, 28, 28]]    [1, 214, 14, 14]          0       
 BasicBlock-19   [[1, 214, 28, 28]]    [1, 239, 14, 14]          0       
BatchNorm2D-61   [[1, 239, 14, 14]]    [1, 239, 14, 14]         956      
   Conv2D-41     [[1, 239, 14, 14]]    [1, 264, 14, 14]       567,864    
BatchNorm2D-62   [[1, 264, 14, 14]]    [1, 264, 14, 14]        1,056     
    ReLU-23      [[1, 264, 14, 14]]    [1, 264, 14, 14]          0       
   Conv2D-42     [[1, 264, 14, 14]]    [1, 264, 14, 14]       627,264    
BatchNorm2D-63   [[1, 264, 14, 14]]    [1, 264, 14, 14]        1,056     
 BasicBlock-20   [[1, 239, 14, 14]]    [1, 264, 14, 14]          0       
BatchNorm2D-64   [[1, 264, 14, 14]]    [1, 264, 14, 14]        1,056     
   Conv2D-43     [[1, 264, 14, 14]]    [1, 289, 14, 14]       686,664    
BatchNorm2D-65   [[1, 289, 14, 14]]    [1, 289, 14, 14]        1,156     
    ReLU-24      [[1, 289, 14, 14]]    [1, 289, 14, 14]          0       
   Conv2D-44     [[1, 289, 14, 14]]    [1, 289, 14, 14]       751,689    
BatchNorm2D-66   [[1, 289, 14, 14]]    [1, 289, 14, 14]        1,156     
 BasicBlock-21   [[1, 264, 14, 14]]    [1, 289, 14, 14]          0       
BatchNorm2D-67   [[1, 289, 14, 14]]    [1, 289, 14, 14]        1,156     
   Conv2D-45     [[1, 289, 14, 14]]     [1, 314, 7, 7]        816,714    
BatchNorm2D-68    [[1, 314, 7, 7]]      [1, 314, 7, 7]         1,256     
    ReLU-25       [[1, 314, 7, 7]]      [1, 314, 7, 7]           0       
   Conv2D-46      [[1, 314, 7, 7]]      [1, 314, 7, 7]        887,364    
BatchNorm2D-69    [[1, 314, 7, 7]]      [1, 314, 7, 7]         1,256     
  AvgPool2D-7    [[1, 289, 14, 14]]     [1, 289, 7, 7]           0       
 BasicBlock-22   [[1, 289, 14, 14]]     [1, 314, 7, 7]           0       
BatchNorm2D-70    [[1, 314, 7, 7]]      [1, 314, 7, 7]         1,256     
   Conv2D-47      [[1, 314, 7, 7]]      [1, 339, 7, 7]        958,014    
BatchNorm2D-71    [[1, 339, 7, 7]]      [1, 339, 7, 7]         1,356     
    ReLU-26       [[1, 339, 7, 7]]      [1, 339, 7, 7]           0       
   Conv2D-48      [[1, 339, 7, 7]]      [1, 339, 7, 7]       1,034,289   
BatchNorm2D-72    [[1, 339, 7, 7]]      [1, 339, 7, 7]         1,356     
 BasicBlock-23    [[1, 314, 7, 7]]      [1, 339, 7, 7]           0       
BatchNorm2D-73    [[1, 339, 7, 7]]      [1, 339, 7, 7]         1,356     
   Conv2D-49      [[1, 339, 7, 7]]      [1, 364, 7, 7]       1,110,564   
BatchNorm2D-74    [[1, 364, 7, 7]]      [1, 364, 7, 7]         1,456     
    ReLU-27       [[1, 364, 7, 7]]      [1, 364, 7, 7]           0       
   Conv2D-50      [[1, 364, 7, 7]]      [1, 364, 7, 7]       1,192,464   
BatchNorm2D-75    [[1, 364, 7, 7]]      [1, 364, 7, 7]         1,456     
 BasicBlock-24    [[1, 339, 7, 7]]      [1, 364, 7, 7]           0       
BatchNorm2D-76    [[1, 364, 7, 7]]      [1, 364, 7, 7]         1,456     
    ReLU-28       [[1, 364, 7, 7]]      [1, 364, 7, 7]           0       
  AvgPool2D-8     [[1, 364, 7, 7]]      [1, 364, 1, 1]           0       
   Linear-2          [[1, 364]]            [1, 10]             3,650     
===========================================================================
Total params: 12,124,672
Trainable params: 12,091,544
Non-trainable params: 33,128
---------------------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 114.81
Params size (MB): 46.25
Estimated Total Size (MB): 161.63
---------------------------------------------------------------------------
{'total_params': 12124672, 'trainable_params': 12091544}
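The totals in this summary can be reproduced by hand. The sketch below counts parameters from the layer layout as read off the summary rows above (a 7x7 stem convolution, then per block BatchNorm, 3x3 conv, BatchNorm, ReLU, 3x3 conv, BatchNorm, then a final BatchNorm and a linear layer). The helper names are hypothetical, and the layout is inferred from the printed table rather than taken from PyramidNet.py. Convolutions are assumed to have no bias, and each BatchNorm channel is counted with four values, as in the table.

```python
def conv_params(c_in, c_out, kernel):
    """Parameters of a bias-free 2D convolution."""
    return c_in * c_out * kernel * kernel


def bn_params(channels):
    """Four values per channel: scale, shift, running mean, running variance."""
    return 4 * channels


def count_basic_pyramidnet(base=64, alpha=300, num_blocks=12, num_classes=10):
    step = alpha / num_blocks
    widths = [int(round(base + step * k)) for k in range(num_blocks + 1)]

    conv_total = conv_params(3, base, 7)   # stem convolution
    bn_total = bn_params(base)             # stem BatchNorm
    for c_in, c_out in zip(widths, widths[1:]):
        conv_total += conv_params(c_in, c_out, 3) + conv_params(c_out, c_out, 3)
        bn_total += bn_params(c_in) + 2 * bn_params(c_out)
    bn_total += bn_params(widths[-1])      # BatchNorm before the classifier

    fc_total = widths[-1] * num_classes + num_classes
    return conv_total + bn_total + fc_total, bn_total


if __name__ == "__main__":
    total, bn_total = count_basic_pyramidnet()
    print("total params:", total)
    print("BatchNorm params:", bn_total)
```

Under these assumptions the count comes to 12,124,672, matching the Total params row. The BatchNorm share is 33,128, which coincides with the Non-trainable params row and suggests that this summary lists all BatchNorm values there.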
In [12]
from PyramidNet import PyramidNet
model_bottleneck = PyramidNet('imagenet', 32, 300, num_classes=10, bottleneck=True)
model_bottleneck = paddle.Model(model_bottleneck)
model_bottleneck.summary((-1, 3, 224, 224))
=> the layer configuration for each stage is set to [2, 2, 2, 2]
---------------------------------------------------------------------------
 Layer (type)       Input Shape          Output Shape         Param #    
===========================================================================
   Conv2D-51     [[1, 3, 224, 224]]   [1, 64, 112, 112]        9,408     
BatchNorm2D-77  [[1, 64, 112, 112]]   [1, 64, 112, 112]         256      
    ReLU-29     [[1, 64, 112, 112]]   [1, 64, 112, 112]          0       
  MaxPool2D-3   [[1, 64, 112, 112]]    [1, 64, 56, 56]           0       
BatchNorm2D-78   [[1, 64, 56, 56]]     [1, 64, 56, 56]          256      
   Conv2D-52     [[1, 64, 56, 56]]     [1, 102, 56, 56]        6,528     
BatchNorm2D-79   [[1, 102, 56, 56]]    [1, 102, 56, 56]         408      
    ReLU-30      [[1, 102, 56, 56]]    [1, 102, 56, 56]          0       
   Conv2D-53     [[1, 102, 56, 56]]    [1, 102, 56, 56]       93,636     
BatchNorm2D-80   [[1, 102, 56, 56]]    [1, 102, 56, 56]         408      
   Conv2D-54     [[1, 102, 56, 56]]    [1, 408, 56, 56]       41,616     
BatchNorm2D-81   [[1, 408, 56, 56]]    [1, 408, 56, 56]        1,632     
 Bottleneck-1    [[1, 64, 56, 56]]     [1, 408, 56, 56]          0       
BatchNorm2D-82   [[1, 408, 56, 56]]    [1, 408, 56, 56]        1,632     
   Conv2D-55     [[1, 408, 56, 56]]    [1, 139, 56, 56]       56,712     
BatchNorm2D-83   [[1, 139, 56, 56]]    [1, 139, 56, 56]         556      
    ReLU-31      [[1, 139, 56, 56]]    [1, 139, 56, 56]          0       
   Conv2D-56     [[1, 139, 56, 56]]    [1, 139, 56, 56]       173,889    
BatchNorm2D-84   [[1, 139, 56, 56]]    [1, 139, 56, 56]         556      
   Conv2D-57     [[1, 139, 56, 56]]    [1, 556, 56, 56]       77,284     
BatchNorm2D-85   [[1, 556, 56, 56]]    [1, 556, 56, 56]        2,224     
 Bottleneck-2    [[1, 408, 56, 56]]    [1, 556, 56, 56]          0       
BatchNorm2D-86   [[1, 556, 56, 56]]    [1, 556, 56, 56]        2,224     
   Conv2D-58     [[1, 556, 56, 56]]    [1, 176, 56, 56]       97,856     
BatchNorm2D-87   [[1, 176, 56, 56]]    [1, 176, 56, 56]         704      
    ReLU-32      [[1, 176, 28, 28]]    [1, 176, 28, 28]          0       
   Conv2D-59     [[1, 176, 56, 56]]    [1, 176, 28, 28]       278,784    
BatchNorm2D-88   [[1, 176, 28, 28]]    [1, 176, 28, 28]         704      
   Conv2D-60     [[1, 176, 28, 28]]    [1, 704, 28, 28]       123,904    
BatchNorm2D-89   [[1, 704, 28, 28]]    [1, 704, 28, 28]        2,816     
  AvgPool2D-9    [[1, 556, 56, 56]]    [1, 556, 28, 28]          0       
 Bottleneck-3    [[1, 556, 56, 56]]    [1, 704, 28, 28]          0       
BatchNorm2D-90   [[1, 704, 28, 28]]    [1, 704, 28, 28]        2,816     
   Conv2D-61     [[1, 704, 28, 28]]    [1, 214, 28, 28]       150,656    
BatchNorm2D-91   [[1, 214, 28, 28]]    [1, 214, 28, 28]         856      
    ReLU-33      [[1, 214, 28, 28]]    [1, 214, 28, 28]          0       
   Conv2D-62     [[1, 214, 28, 28]]    [1, 214, 28, 28]       412,164    
BatchNorm2D-92   [[1, 214, 28, 28]]    [1, 214, 28, 28]         856      
   Conv2D-63     [[1, 214, 28, 28]]    [1, 856, 28, 28]       183,184    
BatchNorm2D-93   [[1, 856, 28, 28]]    [1, 856, 28, 28]        3,424     
 Bottleneck-4    [[1, 704, 28, 28]]    [1, 856, 28, 28]          0       
BatchNorm2D-94   [[1, 856, 28, 28]]    [1, 856, 28, 28]        3,424     
   Conv2D-64     [[1, 856, 28, 28]]    [1, 252, 28, 28]       215,712    
BatchNorm2D-95   [[1, 252, 28, 28]]    [1, 252, 28, 28]        1,008     
    ReLU-34      [[1, 252, 14, 14]]    [1, 252, 14, 14]          0       
   Conv2D-65     [[1, 252, 28, 28]]    [1, 252, 14, 14]       571,536    
BatchNorm2D-96   [[1, 252, 14, 14]]    [1, 252, 14, 14]        1,008     
   Conv2D-66     [[1, 252, 14, 14]]   [1, 1008, 14, 14]       254,016    
BatchNorm2D-97  [[1, 1008, 14, 14]]   [1, 1008, 14, 14]        4,032     
 AvgPool2D-10    [[1, 856, 28, 28]]    [1, 856, 14, 14]          0       
 Bottleneck-5    [[1, 856, 28, 28]]   [1, 1008, 14, 14]          0       
BatchNorm2D-98  [[1, 1008, 14, 14]]   [1, 1008, 14, 14]        4,032     
   Conv2D-67    [[1, 1008, 14, 14]]    [1, 289, 14, 14]       291,312    
BatchNorm2D-99   [[1, 289, 14, 14]]    [1, 289, 14, 14]        1,156     
    ReLU-35      [[1, 289, 14, 14]]    [1, 289, 14, 14]          0       
   Conv2D-68     [[1, 289, 14, 14]]    [1, 289, 14, 14]       751,689    
BatchNorm2D-100  [[1, 289, 14, 14]]    [1, 289, 14, 14]        1,156     
   Conv2D-69     [[1, 289, 14, 14]]   [1, 1156, 14, 14]       334,084    
BatchNorm2D-101 [[1, 1156, 14, 14]]   [1, 1156, 14, 14]        4,624     
 Bottleneck-6   [[1, 1008, 14, 14]]   [1, 1156, 14, 14]          0       
BatchNorm2D-102 [[1, 1156, 14, 14]]   [1, 1156, 14, 14]        4,624     
   Conv2D-70    [[1, 1156, 14, 14]]    [1, 326, 14, 14]       376,856    
BatchNorm2D-103  [[1, 326, 14, 14]]    [1, 326, 14, 14]        1,304     
    ReLU-36       [[1, 326, 7, 7]]      [1, 326, 7, 7]           0       
   Conv2D-71     [[1, 326, 14, 14]]     [1, 326, 7, 7]        956,484    
BatchNorm2D-104   [[1, 326, 7, 7]]      [1, 326, 7, 7]         1,304     
   Conv2D-72      [[1, 326, 7, 7]]     [1, 1304, 7, 7]        425,104    
BatchNorm2D-105  [[1, 1304, 7, 7]]     [1, 1304, 7, 7]         5,216     
 AvgPool2D-11   [[1, 1156, 14, 14]]    [1, 1156, 7, 7]           0       
 Bottleneck-7   [[1, 1156, 14, 14]]    [1, 1304, 7, 7]           0       
BatchNorm2D-106  [[1, 1304, 7, 7]]     [1, 1304, 7, 7]         5,216     
   Conv2D-73     [[1, 1304, 7, 7]]      [1, 364, 7, 7]        474,656    
BatchNorm2D-107   [[1, 364, 7, 7]]      [1, 364, 7, 7]         1,456     
    ReLU-37       [[1, 364, 7, 7]]      [1, 364, 7, 7]           0       
   Conv2D-74      [[1, 364, 7, 7]]      [1, 364, 7, 7]       1,192,464   
BatchNorm2D-108   [[1, 364, 7, 7]]      [1, 364, 7, 7]         1,456     
   Conv2D-75      [[1, 364, 7, 7]]     [1, 1456, 7, 7]        529,984    
BatchNorm2D-109  [[1, 1456, 7, 7]]     [1, 1456, 7, 7]         5,824     
 Bottleneck-8    [[1, 1304, 7, 7]]     [1, 1456, 7, 7]           0       
BatchNorm2D-110  [[1, 1456, 7, 7]]     [1, 1456, 7, 7]         5,824     
    ReLU-38      [[1, 1456, 7, 7]]     [1, 1456, 7, 7]           0       
 AvgPool2D-12    [[1, 1456, 7, 7]]     [1, 1456, 1, 1]           0       
   Linear-3         [[1, 1456]]            [1, 10]            14,570     
===========================================================================
Total params: 8,169,080
Trainable params: 8,094,088
Non-trainable params: 74,992
---------------------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 228.82
Params size (MB): 31.16
Estimated Total Size (MB): 260.56
---------------------------------------------------------------------------
{'total_params': 8169080, 'trainable_params': 8094088}

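Reproducing the paper's training and validation on CIFAR-10 and CIFAR-100

First, reproduce the paper's CIFAR-10 and CIFAR-100 training and validation, using the hyperparameter values from the paper.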

In [ ]
!python train_cifar.py --net 'pyramidnet_bottleneck' --dataset 'cifar10' --alpha 200 --depth 272 --num_classes 10 --epochs 300 --lr 0.1 --batch_size 128
Eval samples: 10000
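The `--depth` and `--alpha` flags together determine how many blocks each stage has and how quickly the width grows. The sketch below works this out under a convention commonly used for CIFAR PyramidNets: three stages, 16 initial channels, two convolutions per basic block or three per bottleneck block, and a 4x output expansion for bottleneck blocks. Whether train_cifar.py follows exactly this rule is an assumption, and the function name is hypothetical.

```python
def cifar_pyramidnet_config(depth, alpha, bottleneck):
    """Return (blocks_per_stage, channels_added_per_block, final_feature_width)."""
    convs_per_block = 3 if bottleneck else 2
    # Two layers (stem conv and classifier) sit outside the three stages.
    blocks_per_stage = (depth - 2) // (3 * convs_per_block)
    total_blocks = 3 * blocks_per_stage
    add_rate = alpha / total_blocks
    final_width = 16 + alpha
    if bottleneck:
        final_width *= 4   # bottleneck blocks expand their output by 4
    return blocks_per_stage, add_rate, final_width


if __name__ == "__main__":
    blocks, add_rate, width = cifar_pyramidnet_config(272, 200, bottleneck=True)
    print("blocks per stage:", blocks)
    print("channels added per block: %.3f" % add_rate)
    print("final feature width:", width)
```

For `--depth 272 --alpha 200` with bottleneck blocks, this gives 30 blocks per stage, so each of the 90 blocks adds only a little over two channels.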

CIFAR-10 training and validation curves

Figure 4. Training and validation of PyramidNet_bottleneck on CIFAR-10

In [ ]
!python train_cifar.py --net 'pyramidnet_bottleneck' --dataset 'cifar100'  --alpha 200 --depth 272 --num_classes 100 --epochs 300 --lr 0.5 --batch_size 128
Eval samples: 10000

CIFAR-100 training and validation curves

Figure 5. Training and validation of PyramidNet_bottleneck on CIFAR-100

Training and validating PyramidNet on the animal dataset

Training

In [16]
!python train.py --net 'pyramidnet'
In [15]
!python train.py --net 'pyramidnet_bottleneck'

Validation

In [ ]
!python eval.py --net 'pyramidnet'
=> the layer configuration for each stage is set to [3, 3, 3, 3]
W0213 21:36:28.090073 12789 device_context.cc:362] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W0213 21:36:28.094696 12789 device_context.cc:372] device: 0, cuDNN Version: 7.6.
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 103/103 [==============================] - loss: 0.8536 - acc: 0.7357 - 187ms/step        
Eval samples: 3276
{'loss': [0.8535936], 'acc': 0.7356532356532357}
In [ ]
!python eval.py --net 'pyramidnet_bottleneck'
=> the layer configuration for each stage is set to [2, 2, 2, 2]
W0214 09:30:31.742739 30668 device_context.cc:362] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W0214 09:30:31.747535 30668 device_context.cc:372] device: 0, cuDNN Version: 7.6.
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 103/103 [==============================] - loss: 1.2630 - acc: 0.6734 - 194ms/step        
Eval samples: 3276
{'loss': [1.2630014], 'acc': 0.6733821733821734}

Training and validation curves

Figure 6. Training and validation of PyramidNet

Training and validating ResNet50 on the animal dataset

Training

In [14]
!python train.py --net 'resnet'

Validation

In [ ]
!python eval.py --net 'resnet'
W0213 21:34:50.038996 12684 device_context.cc:362] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W0213 21:34:50.043457 12684 device_context.cc:372] device: 0, cuDNN Version: 7.6.
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 103/103 [==============================] - loss: 1.4232 - acc: 0.5888 - 191ms/step        
Eval samples: 3276
{'loss': [1.4232028], 'acc': 0.5888278388278388}

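Training and validation curves

Figure 7. Training and validation of ResNet

Comparison

With this project's configuration, both PyramidNet-32 (alpha=300, without bottleneck) and PyramidNet-32 (alpha=300, with bottleneck) reach higher validation accuracy than ResNet50 (0.7357 and 0.6734 versus 0.5888) while having far fewer parameters. The ResNet50 model size is 141.5M, the PyramidNet-32 (alpha=300, without bottleneck) model size is 72.6M, and the PyramidNet-32 (alpha=300, with bottleneck) model size is only 48.9M.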


Figure 8. Training and validation comparison of PyramidNet and ResNet
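The comparison can be put into relative terms with a few lines of arithmetic. The accuracies below are copied from the eval.py outputs above and the sizes from the comparison paragraph; the dictionary layout and the function name are only illustrative.

```python
results = {
    "resnet50": {"acc": 0.5888278388278388, "size_m": 141.5},
    "pyramidnet": {"acc": 0.7356532356532357, "size_m": 72.6},
    "pyramidnet_bottleneck": {"acc": 0.6733821733821734, "size_m": 48.9},
}


def compare_to_baseline(results, baseline="resnet50"):
    """Accuracy gain and relative size reduction of each model versus the baseline."""
    base = results[baseline]
    rows = {}
    for name, r in results.items():
        if name == baseline:
            continue
        rows[name] = {
            "acc_gain": r["acc"] - base["acc"],
            "size_reduction": 1.0 - r["size_m"] / base["size_m"],
        }
    return rows


if __name__ == "__main__":
    for name, row in compare_to_baseline(results).items():
        print("%-22s acc gain %+.4f, size reduction %.1f%%"
              % (name, row["acc_gain"], 100 * row["size_reduction"]))
```

On these numbers the non-bottleneck model gains about 14.7 accuracy points over ResNet50 at roughly half the model size, and the bottleneck model gains about 8.5 points at roughly a third of the size.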
