0

0

Keras模型训练与评估精度不一致问题解析与解决方案

花韻仙語

花韻仙語

发布时间:2025-08-11 17:20:02

|

1030人浏览过

|

来源于php中文网

原创

Keras模型训练与评估精度不一致问题解析与解决方案

本文深入探讨了Keras模型在训练过程中(model.fit)报告的精度与模型评估(model.evaluate)精度不一致的常见问题。通过分析两者计算机制的差异,特别是批量更新和指标聚合方式,揭示了产生差异的根本原因。文章提供了通过引入validation_data并在自定义回调中监控val_accuracy的解决方案,确保训练过程中的监控指标与最终评估结果保持一致,从而提高模型训练的可靠性和可解释性。

1. 问题现象与初步分析

在使用keras进行模型训练时,我们可能会观察到model.fit在每个epoch结束时打印的accuracy(训练精度)与训练结束后使用model.evaluate在相同训练集上计算得到的精度存在差异。例如,fit报告的精度可能达到1.0,而evaluate的结果却略低于1.0。这种差异尤其在自定义回调函数中依赖logs['accuracy']进行逻辑判断(如提前停止)时,可能导致意外的行为。

造成这种差异的根本原因在于model.fit和model.evaluate计算指标的方式不同:

  • model.fit中的训练精度(accuracy):在每个epoch内,模型会分批次(batch)处理数据并更新权重。model.fit报告的accuracy是该epoch内所有批次精度的平均值。重要的是,每个批次的精度是在该批次数据被处理之前(或在权重更新之后但尚未处理下一个批次之前)计算的。这意味着,对于一个epoch内的不同批次,模型的权重可能在不断变化,因此计算出的精度是基于动态变化的模型状态。当一个epoch结束时,报告的accuracy是整个epoch中,模型在处理各个批次时所达到的平均性能。

  • model.evaluate中的精度:model.evaluate函数在调用时,会使用模型当前的最终权重来对整个数据集进行一次性(或分批次)评估。它不会在评估过程中更新权重。因此,model.evaluate的结果代表了模型在固定权重下的整体性能。

当batch_size较小,或者模型在训练初期权重变化较大时,model.fit报告的平均精度与model.evaluate在最终权重下计算的精度之间就可能出现显著差异。

2. 解决方案:引入验证集与监控val_accuracy

解决这一问题的关键在于,让model.fit在每个epoch结束时,使用当前epoch的最终权重,在一个固定数据集上计算指标。这可以通过fit方法的validation_data参数来实现。即使我们希望在训练集上进行评估以比较,也可以将训练集本身作为validation_data。

当validation_data被提供时,Keras会在每个epoch结束时,使用该epoch的最终模型权重对验证数据进行一次评估,并报告val_loss和val_accuracy等指标。这些val_accuracy值将与model.evaluate在相同数据集上得到的结果更加一致,因为它反映的是模型在固定权重下的表现。

MiniMax Agent
MiniMax Agent

MiniMax平台推出的Agent智能体助手

下载

对于自定义的提前停止回调,也应该监控val_accuracy而不是accuracy。

2.1 示例代码(修正后)

以下是修正后的代码示例,展示了如何通过引入validation_data并调整自定义回调来解决精度不一致问题:

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import keras
import random
from keras import layers
from keras.callbacks import EarlyStopping
from keras.optimizers import Adam


def random_seed(seed_num=1):
    """
    设置随机种子以确保结果可复现。
    """
    np.random.seed(seed_num)
    tf.random.set_seed(seed_num)
    random.seed(seed_num)


class CustomEarlyStopping(keras.callbacks.Callback):
    """
    自定义提前停止回调,根据验证精度停止训练。
    """
    def __init__(self, threshold):
        super().__init__()
        self.threshold = threshold

    def on_epoch_end(self, epoch, logs=None):
        # 监控 'val_accuracy' 而不是 'accuracy'
        accuracy = logs.get("val_accuracy")
        if accuracy is not None and accuracy >= self.threshold:
            print(f"\n达到验证精度阈值 {self.threshold},停止训练。")
            self.model.stop_training = True

# 1. 数据准备
x = np.arange(-20, 30, 0.1)
y = np.zeros_like(x)
df = pd.DataFrame({'x': x, 'y': y})
# 创建一个简单的二分类问题:x < 10 为 0,否则为 1
df.y = df.x.map(lambda x_val: 0 if x_val < 10 else 1)
X_train = df.drop(columns='y')
y_train = df.y

# 2. 模型构建
random_seed() # 设置随机种子
model = keras.Sequential([
    layers.Input(shape=X_train.shape[-1]),
    layers.Normalization(), # 数据归一化层
    layers.Dense(1, activation='relu'), # 第一个全连接层
    layers.Dense(1, activation='sigmoid'), # 输出层,用于二分类
])

# 3. 模型编译
model.compile(
    optimizer=Adam(learning_rate=0.1), # 使用Adam优化器
    loss='binary_crossentropy', # 二元交叉熵损失函数
    metrics=['accuracy'], # 监控精度
)

# 4. 模型训练
history = model.fit(
    X_train, y_train,
    validation_data=(X_train, y_train), # 关键:将训练集也作为验证集
    batch_size=128,
    epochs=300, # 增加epochs以确保模型充分训练
    callbacks=[
        CustomEarlyStopping(1.0) # 使用自定义提前停止回调
    ]
)
history_df = pd.DataFrame(history.history)

# 5. 结果验证
# 获取history中记录的最后一个训练精度(注意这里仍然是训练精度)
last_accuracy_fit = history_df.accuracy.tolist()[-1]
# 获取history中记录的最后一个验证精度
last_accuracy_val = history_df.val_accuracy.tolist()[-1]
# 使用model.evaluate在训练集上进行评估
predict_accuracy = model.evaluate(X_train, y_train, verbose=0)[-1] # verbose=0 不打印进度条

print(f'Fit报告的最后一个训练精度 (accuracy): {last_accuracy_fit:.6f}')
print(f'Fit报告的最后一个验证精度 (val_accuracy): {last_accuracy_val:.6f}')
print(f'model.evaluate评估的精度: {predict_accuracy:.6f}')

# 预期输出:
# Fit报告的最后一个训练精度 (accuracy): 1.000000
# Fit报告的最后一个验证精度 (val_accuracy): 1.000000
# model.evaluate评估的精度: 1.000000

2.2 修正点解析

  1. validation_data=(X_train, y_train): 在model.fit中加入了validation_data参数,并将训练数据本身作为验证数据传入。这使得Keras在每个epoch结束时,都会使用该epoch的最终模型权重对X_train和y_train进行一次完整的评估,并生成val_accuracy指标。
  2. accuracy = logs.get("val_accuracy"): 在自定义的CustomEarlyStopping回调中,将监控的指标从logs["accuracy"]改为了logs.get("val_accuracy")。这样,提前停止的判断依据就是模型在epoch结束时,使用该epoch的最终权重在固定验证集(这里是训练集)上的表现。
  3. 增加epochs: 为了确保模型有足够的机会达到理想精度,将epochs从100增加到了300。

通过这些修改,model.fit报告的val_accuracy与model.evaluate的结果将保持高度一致,因为它们都是在相同的固定数据集上,使用相同的模型最终权重进行计算的。

3. 注意事项与总结

  • 理解指标含义:始终要区分model.fit在训练过程中报告的批次平均精度(accuracy)和epoch结束时在固定数据集上评估的精度(val_accuracy)。前者是动态的,后者是静态的,更具代表性。
  • 验证集的重要性:在实际项目中,validation_data通常应该是与训练集独立的数据集,用于监控模型的泛化能力,防止过拟合。本教程中为了演示精度一致性问题,使用了训练集作为验证集,但在生产环境中应避免。
  • model.evaluate的权威性:model.evaluate始终是评估模型在特定数据集上最终性能的“黄金标准”,因为它是在模型训练完成后,使用其最终权重进行的一次性评估。
  • 提前停止策略:在使用提前停止回调时,务必基于val_loss或val_accuracy等验证指标进行判断,而不是训练指标,以确保模型在验证集上表现良好时停止训练,避免过拟合。

通过理解Keras内部指标计算的机制并正确配置model.fit的参数,我们可以更准确地监控模型训练过程,并确保训练结果的可靠性。

相关专题

更多
Java 桌面应用开发(JavaFX 实战)
Java 桌面应用开发(JavaFX 实战)

本专题系统讲解 Java 在桌面应用开发领域的实战应用,重点围绕 JavaFX 框架,涵盖界面布局、控件使用、事件处理、FXML、样式美化(CSS)、多线程与UI响应优化,以及桌面应用的打包与发布。通过完整示例项目,帮助学习者掌握 使用 Java 构建现代化、跨平台桌面应用程序的核心能力。

37

2026.01.14

php与html混编教程大全
php与html混编教程大全

本专题整合了php和html混编相关教程,阅读专题下面的文章了解更多详细内容。

19

2026.01.13

PHP 高性能
PHP 高性能

本专题整合了PHP高性能相关教程大全,阅读专题下面的文章了解更多详细内容。

37

2026.01.13

MySQL数据库报错常见问题及解决方法大全
MySQL数据库报错常见问题及解决方法大全

本专题整合了MySQL数据库报错常见问题及解决方法,阅读专题下面的文章了解更多详细内容。

19

2026.01.13

PHP 文件上传
PHP 文件上传

本专题整合了PHP实现文件上传相关教程,阅读专题下面的文章了解更多详细内容。

16

2026.01.13

PHP缓存策略教程大全
PHP缓存策略教程大全

本专题整合了PHP缓存相关教程,阅读专题下面的文章了解更多详细内容。

6

2026.01.13

jQuery 正则表达式相关教程
jQuery 正则表达式相关教程

本专题整合了jQuery正则表达式相关教程大全,阅读专题下面的文章了解更多详细内容。

3

2026.01.13

交互式图表和动态图表教程汇总
交互式图表和动态图表教程汇总

本专题整合了交互式图表和动态图表的相关内容,阅读专题下面的文章了解更多详细内容。

45

2026.01.13

nginx配置文件详细教程
nginx配置文件详细教程

本专题整合了nginx配置文件相关教程详细汇总,阅读专题下面的文章了解更多详细内容。

9

2026.01.13

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
10分钟--Midjourney创作自己的漫画
10分钟--Midjourney创作自己的漫画

共1课时 | 0.1万人学习

Midjourney 关键词系列整合
Midjourney 关键词系列整合

共13课时 | 0.9万人学习

AI绘画教程
AI绘画教程

共2课时 | 0.2万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号