0

0

使用 VGG16 进行 MNIST 数字识别的迁移学习

碧海醫心

碧海醫心

发布时间:2025-08-19 19:12:16

|

719人浏览过

|

来源于php中文网

原创

使用 vgg16 进行 mnist 数字识别的迁移学习

本文档旨在指导读者如何利用 VGG16 模型进行 MNIST 手写数字识别的迁移学习。我们将重点介绍如何构建模型,加载预训练权重,以及解决可能出现的 GPU 配置问题。本文将帮助您理解迁移学习的基本原理,并提供一个可运行的示例,用于 MNIST 数据集的分类。

使用 VGG16 进行 MNIST 数字识别的迁移学习

迁移学习是一种强大的技术,它允许我们将一个在大型数据集上训练过的模型(例如,ImageNet)的知识迁移到另一个相关但更小的数据集上(例如,MNIST)。这种方法可以显著提高模型的性能,尤其是在数据量有限的情况下。本文将介绍如何使用 VGG16 模型进行 MNIST 手写数字识别的迁移学习。

准备工作

在开始之前,请确保您已安装以下库:

  • TensorFlow
  • Keras
  • NumPy
import tensorflow as tf
from tensorflow.keras.applications import VGG16
from tensorflow.keras import layers, models
import numpy as np

构建迁移学习模型

以下代码展示了如何使用 VGG16 模型作为基础模型,并添加自定义层以适应 MNIST 数据集。

class VGG16TransferLearning(tf.keras.Model):
    def __init__(self, base_model, models):
        super(VGG16TransferLearning, self).__init__()
        # base model
        self.base_model = base_model

        # other layers
        self.flatten = tf.keras.layers.Flatten()
        self.dense1 = tf.keras.layers.Dense(512, activation='relu')
        self.dense2 = tf.keras.layers.Dense(512, activation='relu')
        self.dense3 = tf.keras.layers.Dense(10)
        self.layers_list = [self.flatten, self.dense1, self.dense2, self.dense3]

        # instantiate the base model with other layers
        self.model = models.Sequential(
            [self.base_model, *self.layers_list]
        )

    def call(self, *args, **kwargs):
        out = args[0]
        for layer in self.model.layers:
            out = layer(out)
        return out

在这个类中,VGG16TransferLearning 继承自 tf.keras.Model。__init__ 方法初始化了 VGG16 基础模型,以及一些自定义的 Dense 层。call 方法定义了模型的前向传播过程。

加载 VGG16 预训练权重

在实例化 VGG16TransferLearning 类之前,我们需要加载 VGG16 模型并冻结其权重,以防止在训练过程中修改预训练的权重。

# 假设 x_train[0].shape 是 (75, 75, 3)
input_shape = (75, 75, 3)  # 确保输入形状与您的数据一致

base_model = VGG16(weights="imagenet", include_top=False, input_shape=input_shape)
base_model.trainable = False  # 冻结 VGG16 的权重

VGG16(weights="imagenet", include_top=False, input_shape=x_train[0].shape) 这行代码加载了在 ImageNet 上预训练的 VGG16 模型。include_top=False 表示我们不包含 VGG16 的顶层分类器,因为我们需要根据 MNIST 数据集进行自定义。 base_model.trainable = False 冻结了 VGG16 模型的权重,防止在训练过程中更新这些权重。

编译和训练模型

接下来,我们将实例化 VGG16TransferLearning 类,编译模型,并使用 MNIST 数据集进行训练。

model = VGG16TransferLearning(base_model, models)

model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(),
              optimizer=tf.keras.optimizers.legacy.Adam(),
              metrics=['accuracy'])

# 假设 x_train, y_train, x_test, y_test 已经加载
# 并且 x_train 的形状是 (num_samples, 75, 75, 3)
# 并且 y_train 是整数标签
model.fit(x_train, y_train, epochs=10, validation_data=(x_test, y_test))

model.compile 方法配置了模型的训练过程,包括损失函数、优化器和评估指标。model.fit 方法使用训练数据训练模型,并使用验证数据评估模型的性能。

GPU 配置问题

在运行上述代码时,可能会遇到 GPU 配置问题,导致 Kernel 死机。这通常是因为 TensorFlow 没有正确检测到 GPU。以下是一些可能的解决方案:

  1. 检查 TensorFlow 是否检测到 GPU:

    PHP高级开发技巧与范例
    PHP高级开发技巧与范例

    PHP是一种功能强大的网络程序设计语言,而且易学易用,移植性和可扩展性也都非常优秀,本书将为读者详细介绍PHP编程。 全书分为预备篇、开始篇和加速篇三大部分,共9章。预备篇主要介绍一些学习PHP语言的预备知识以及PHP运行平台的架设;开始篇则较为详细地向读者介绍PKP语言的基本语法和常用函数,以及用PHP如何对MySQL数据库进行操作;加速篇则通过对典型实例的介绍来使读者全面掌握PHP。 本书

    下载
    print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

    如果输出为 0,则表示 TensorFlow 没有检测到 GPU。

  2. 安装正确的 TensorFlow 版本:

    确保您安装了与您的 GPU 兼容的 TensorFlow 版本。例如,如果您使用的是 NVIDIA GPU,则需要安装 tensorflow-gpu 版本。

  3. 配置 CUDA 和 cuDNN:

    确保您已正确安装和配置 CUDA 和 cuDNN,并且它们的版本与 TensorFlow 兼容。

  4. 设置环境变量:

    设置以下环境变量:

    • CUDA_HOME: CUDA 的安装路径。
    • LD_LIBRARY_PATH: CUDA 库的路径。

总结

本文介绍了如何使用 VGG16 模型进行 MNIST 手写数字识别的迁移学习。我们讨论了如何构建模型,加载预训练权重,以及解决可能出现的 GPU 配置问题。通过迁移学习,我们可以利用在大规模数据集上训练的模型,快速构建高性能的图像分类器。

注意事项:

  • 确保输入数据的形状与 VGG16 模型的输入形状兼容。
  • 根据您的数据集和计算资源,调整模型的超参数,例如学习率和 batch size。
  • 如果遇到 GPU 配置问题,请仔细检查 TensorFlow、CUDA 和 cuDNN 的安装和配置。

相关专题

更多
Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习
Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习

PyTorch 是一种用于构建深度学习模型的功能完备框架,是一种通常用于图像识别和语言处理等应用程序的机器学习。 使用Python 编写,因此对于大多数机器学习开发者而言,学习和使用起来相对简单。 PyTorch 的独特之处在于,它完全支持GPU,并且使用反向模式自动微分技术,因此可以动态修改计算图形。

20

2025.12.22

Python 深度学习框架与TensorFlow入门
Python 深度学习框架与TensorFlow入门

本专题深入讲解 Python 在深度学习与人工智能领域的应用,包括使用 TensorFlow 搭建神经网络模型、卷积神经网络(CNN)、循环神经网络(RNN)、数据预处理、模型优化与训练技巧。通过实战项目(如图像识别与文本生成),帮助学习者掌握 如何使用 TensorFlow 开发高效的深度学习模型,并将其应用于实际的 AI 问题中。

17

2026.01.07

Java 桌面应用开发(JavaFX 实战)
Java 桌面应用开发(JavaFX 实战)

本专题系统讲解 Java 在桌面应用开发领域的实战应用,重点围绕 JavaFX 框架,涵盖界面布局、控件使用、事件处理、FXML、样式美化(CSS)、多线程与UI响应优化,以及桌面应用的打包与发布。通过完整示例项目,帮助学习者掌握 使用 Java 构建现代化、跨平台桌面应用程序的核心能力。

61

2026.01.14

php与html混编教程大全
php与html混编教程大全

本专题整合了php和html混编相关教程,阅读专题下面的文章了解更多详细内容。

31

2026.01.13

PHP 高性能
PHP 高性能

本专题整合了PHP高性能相关教程大全,阅读专题下面的文章了解更多详细内容。

73

2026.01.13

MySQL数据库报错常见问题及解决方法大全
MySQL数据库报错常见问题及解决方法大全

本专题整合了MySQL数据库报错常见问题及解决方法,阅读专题下面的文章了解更多详细内容。

20

2026.01.13

PHP 文件上传
PHP 文件上传

本专题整合了PHP实现文件上传相关教程,阅读专题下面的文章了解更多详细内容。

24

2026.01.13

PHP缓存策略教程大全
PHP缓存策略教程大全

本专题整合了PHP缓存相关教程,阅读专题下面的文章了解更多详细内容。

7

2026.01.13

jQuery 正则表达式相关教程
jQuery 正则表达式相关教程

本专题整合了jQuery正则表达式相关教程大全,阅读专题下面的文章了解更多详细内容。

4

2026.01.13

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
10分钟--Midjourney创作自己的漫画
10分钟--Midjourney创作自己的漫画

共1课时 | 0.1万人学习

Midjourney 关键词系列整合
Midjourney 关键词系列整合

共13课时 | 0.9万人学习

AI绘画教程
AI绘画教程

共2课时 | 0.2万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号