0

0

PyTorch中VGG-19模型的微调策略:全层与特定全连接层更新实践

聖光之護

聖光之護

发布时间:2025-11-22 13:49:20

|

448人浏览过

|

来源于php中文网

原创

PyTorch中VGG-19模型的微调策略:全层与特定全连接层更新实践

本文详细介绍了在pytorch中对预训练vgg-19模型进行微调的两种核心策略:一是更新模型所有层的权重以适应新任务;二是通过冻结大部分层,仅微调vgg-19分类器中的特定全连接层(fc1和fc2)。文章将通过示例代码演示如何精确控制参数的梯度计算,并强调根据新数据集的类别数量调整最终输出层的重要性,从而高效地迁移学习。

深度学习领域,迁移学习是一种强大的技术,它允许我们利用在大规模数据集(如ImageNet)上预训练的模型,并将其应用于新的、通常数据量较小的任务。VGG-19作为一种经典的卷积神经网络架构,因其简洁的结构和强大的特征提取能力,常被用作迁移学习的基石。在PyTorch中,我们可以灵活地控制模型的哪些部分参与训练(即微调),以达到最佳的任务适应性。

VGG-19模型结构概览

VGG-19模型由特征提取器(features)、自适应平均池化层(avgpool)和分类器(classifier)三大部分组成。其中,分类器部分通常包含多个全连接层(Linear layers),用于最终的分类任务。了解其结构对于精确控制微调至关重要。

典型的VGG-19分类器结构如下:

  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True) # FC1
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True) # FC2
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True) # 原始输出层 (Original output layer)
  )

从上述结构可以看出,classifier[0]对应第一个全连接层(FC1),classifier[3]对应第二个全连接层(FC2),而classifier[6]则是原始模型针对ImageNet数据集的1000类输出层。

策略一:微调VGG-19所有层

这种策略适用于新任务与原始预训练任务差异较大,或者新数据集足够大,足以支持对整个网络进行训练的情况。通过微调所有层,模型可以最大限度地适应新任务的特征分布。

实现步骤:

  1. 加载预训练的VGG-19模型。
  2. 将模型的所有参数的requires_grad属性设置为True,确保所有层在训练过程中都会更新权重。
  3. 根据新任务的类别数量,替换模型的最终分类层。

示例代码:

聚蜂消防BeesFPD
聚蜂消防BeesFPD

关注消防领域的智慧云平台

下载
import torch.nn as nn
from torchvision import models
from torchvision.models import VGG19_Weights

# 1. 加载预训练的VGG-19模型
# 推荐使用 weights 参数加载预训练权重
model_all_layers = models.vgg19(weights=VGG19_Weights.IMAGENET1K_V1)

# 2. 设置所有层的参数为可训练
for param in model_all_layers.parameters():
    param.requires_grad = True # 确保所有参数都参与梯度计算和更新

# 3. 替换最终分类层以适应新任务的类别数
# 假设您的新数据集有 num_classes 个类别
# 请根据实际情况定义 num_classes,例如:num_classes = len(your_dataset.class_to_idx)
num_classes = 10 # 示例值,请替换为您的实际类别数
in_features = model_all_layers.classifier[6].in_features # 获取原始输出层的输入特征数
model_all_layers.classifier[6] = nn.Linear(in_features, num_classes)

print("VGG-19模型已设置为微调所有层,并更新了最终分类层。")
# 此时,model_all_layers 即可用于训练

策略二:选择性微调特定全连接层(FC1和FC2)

当新数据集相对较小,或者我们希望利用预训练模型强大的特征提取能力,同时避免过拟合时,通常会选择冻结大部分卷积层,只微调分类器中的部分层。这种方法可以有效地在保持模型泛化能力的同时,使其适应特定任务。

实现步骤:

  1. 加载预训练的VGG-19模型。
  2. 首先将模型的所有参数的requires_grad属性设置为False,冻结所有层。
  3. 然后,针对需要微调的特定全连接层(FC1和FC2),将其参数的requires_grad属性设置为True。
  4. 根据新任务的类别数量,替换模型的最终分类层。

示例代码:

import torch.nn as nn
from torchvision import models
from torchvision.models import VGG19_Weights

# 1. 加载预训练的VGG-19模型
model_fc_layers = models.vgg19(weights=VGG19_Weights.IMAGENET1K_V1)

# 2. 冻结所有层的参数
for param in model_fc_layers.parameters():
    param.requires_grad = False # 默认冻结所有层

# 3. 解冻FC1和FC2层的参数
# FC1 对应 classifier[0]
for param in model_fc_layers.classifier[0].parameters():
    param.requires_grad = True

# FC2 对应 classifier[3]
for param in model_fc_layers.classifier[3].parameters():
    param.requires_grad = True

# 4. 替换最终分类层以适应新任务的类别数
# 假设您的新数据集有 num_classes 个类别
num_classes = 10 # 示例值,请替换为您的实际类别数
in_features = model_fc_layers.classifier[6].in_features # 获取原始输出层的输入特征数
model_fc_layers.classifier[6] = nn.Linear(in_features, num_classes)
# 注意:新替换的 nn.Linear 层默认其参数 requires_grad=True,因此无需额外设置

print("VGG-19模型已设置为仅微调FC1、FC2和最终分类层。")
# 此时,model_fc_layers 即可用于训练

关于最终分类层的处理

无论选择哪种微调策略,替换VGG-19模型的最终分类层(即classifier[6])都是一个推荐且通常是必要的步骤。

  • 必要性: 如果您的新任务的类别数量与ImageNet(1000类)不同,那么模型的输出维度必须与新任务的类别数量匹配,否则无法进行正确的损失计算和分类。
  • 推荐性: 即使您的新任务恰好也有1000个类别,但这些类别的具体含义很可能与ImageNet的类别不同。替换并重新训练这个输出层,可以帮助模型更好地学习区分新任务中特定类别的特征,从而提高分类性能。新的nn.Linear层会以随机初始化的权重开始训练,并根据您的数据集进行学习。

注意事项与最佳实践

  1. 加载预训练权重: 在PyTorch 0.13及更高版本中,推荐使用weights=VGG19_Weights.IMAGENET1K_V1来加载预训练权重,而不是已弃用的pretrained=True。
  2. 优化器: 在微调时,可能需要为冻结层和解冻层设置不同的学习率。例如,对于预训练的层使用较小的学习率,对于新添加或解冻的层使用较大的学习率。PyTorch的优化器可以接受参数组,方便实现这一目标。
  3. 数据预处理: 确保您的输入数据经过与ImageNet预训练时相同的预处理步骤,包括图像大小调整(通常为224x224)、归一化(使用ImageNet的均值和标准差)。
  4. 训练循环: 微调过程与从头开始训练模型类似,需要定义损失函数、优化器,并进行迭代训练。
  5. 过拟合: 尤其是在数据集较小的情况下,微调时需要警惕过拟合。可以采用数据增强、Dropout、早停(Early Stopping)等技术来缓解。

总结

本文详细阐述了在PyTorch中对VGG-19模型进行微调的两种主要策略:全面微调和选择性微调特定全连接层。通过精确控制requires_grad属性,我们可以灵活地决定模型哪些部分参与训练,从而根据具体任务和数据集的特点,实现高效的迁移学习。理解并正确应用这些策略,是利用预训练模型解决实际计算机视觉问题的关键。

相关专题

更多
pytorch是干嘛的
pytorch是干嘛的

pytorch是一个基于python的深度学习框架,提供以下主要功能:动态图计算,提供灵活性。强大的张量操作,实现高效处理。自动微分,简化梯度计算。预构建的神经网络模块,简化模型构建。各种优化器,用于性能优化。想了解更多pytorch的相关内容,可以阅读本专题下面的文章。

431

2024.05.29

Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习
Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习

PyTorch 是一种用于构建深度学习模型的功能完备框架,是一种通常用于图像识别和语言处理等应用程序的机器学习。 使用Python 编写,因此对于大多数机器学习开发者而言,学习和使用起来相对简单。 PyTorch 的独特之处在于,它完全支持GPU,并且使用反向模式自动微分技术,因此可以动态修改计算图形。

19

2025.12.22

Java 桌面应用开发(JavaFX 实战)
Java 桌面应用开发(JavaFX 实战)

本专题系统讲解 Java 在桌面应用开发领域的实战应用,重点围绕 JavaFX 框架,涵盖界面布局、控件使用、事件处理、FXML、样式美化(CSS)、多线程与UI响应优化,以及桌面应用的打包与发布。通过完整示例项目,帮助学习者掌握 使用 Java 构建现代化、跨平台桌面应用程序的核心能力。

61

2026.01.14

php与html混编教程大全
php与html混编教程大全

本专题整合了php和html混编相关教程,阅读专题下面的文章了解更多详细内容。

31

2026.01.13

PHP 高性能
PHP 高性能

本专题整合了PHP高性能相关教程大全,阅读专题下面的文章了解更多详细内容。

73

2026.01.13

MySQL数据库报错常见问题及解决方法大全
MySQL数据库报错常见问题及解决方法大全

本专题整合了MySQL数据库报错常见问题及解决方法,阅读专题下面的文章了解更多详细内容。

20

2026.01.13

PHP 文件上传
PHP 文件上传

本专题整合了PHP实现文件上传相关教程,阅读专题下面的文章了解更多详细内容。

24

2026.01.13

PHP缓存策略教程大全
PHP缓存策略教程大全

本专题整合了PHP缓存相关教程,阅读专题下面的文章了解更多详细内容。

7

2026.01.13

jQuery 正则表达式相关教程
jQuery 正则表达式相关教程

本专题整合了jQuery正则表达式相关教程大全,阅读专题下面的文章了解更多详细内容。

4

2026.01.13

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
SciPy 教程
SciPy 教程

共10课时 | 1.1万人学习

R 教程
R 教程

共45课时 | 5万人学习

SQL 教程
SQL 教程

共61课时 | 3.4万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号