0

0

Autokeras中标签编码、随机种子对模型性能的影响及复现性策略

聖光之護

聖光之護

发布时间:2025-09-23 23:26:16

|

337人浏览过

|

来源于php中文网

原创

Autokeras中标签编码、随机种子对模型性能的影响及复现性策略

在使用Autokeras的StructuredDataClassifier时,直接使用One-Hot编码标签与转换为整数标签可能导致显著的性能差异。这种差异并非源于Autokeras对标签处理方式的根本性错误,而是通常与随机种子在模型训练和超参数搜索过程中的影响密切相关。为确保模型性能的稳定性和实验结果的可复现性,正确设置随机种子并理解Autokeras的内部机制至关重要。

Autokeras中的标签处理机制

在机器学习分类任务中,标签编码是数据预处理的关键一步。常见的编码方式包括one-hot编码和整数编码。对于autokeras的structureddataclassifier,它被设计为处理分类任务,通常期望接收整数形式的类别标签。即使您提供one-hot编码的标签,autokeras在内部处理时也会将其视为分类问题,并在其内部管道中进行相应的转换和处理。

实际上,autokeras在接收到整数标签后,会自行将其转换为One-Hot编码形式,以便与通常用于多分类任务的损失函数(如CategoricalCrossentropy)兼容。您可以通过检查clf.outputs[0].in_blocks[0].get_hyper_preprocessors()来验证其预处理器链中是否存在OneHotEncoder对象,以及通过clf.outputs[0].in_blocks[0].loss来确认所使用的损失函数。这意味着,无论您是提供原始的One-Hot编码还是转换后的整数标签,最终模型训练使用的内部标签表示和损失函数通常是一致的。因此,当观察到两者之间存在巨大性能差异(例如从0.40到0.97)时,问题往往不在于标签编码的“正确性”,而在于其他因素。

随机种子与模型复现性

Autokeras作为一种自动化机器学习(AutoML)工具,在寻找最佳模型架构和超参数时,会执行大量的随机操作,例如:

  • 超参数搜索空间探索: 不同的随机初始化可能导致搜索算法探索不同的超参数组合。
  • 模型权重初始化: 神经网络的初始权重通常是随机的。
  • 数据洗牌: 训练数据在每个epoch开始前通常会被随机洗牌。
  • Dropout层: Dropout操作本身具有随机性。

这些随机性在每次运行代码时都可能产生不同的结果,尤其是在max_trials(最大尝试次数)参数较小的情况下。当随机性导致模型在超参数搜索阶段选择了一个次优架构或初始化了一个不利的权重集时,即使输入数据和标签处理方式看似正确,也可能导致性能急剧下降。这正是本案例中观察到One-Hot编码直接输入导致低准确率(0.40)而整数编码导致高准确率(0.97)的根本原因——不同的随机种子导致了不同的超参数搜索路径和最终模型。

确保Autokeras模型复现性的策略

为了解决随机性带来的性能波动问题,并确保实验结果的可复现性,我们需要显式地设置随机种子。仅仅在StructuredDataClassifier构造函数中设置seed参数可能不足以完全控制所有随机源。更全面的方法是使用Keras提供的工具来设置全局随机种子。

以下是确保Autokeras模型复现性的推荐步骤:

Fish Audio
Fish Audio

为所有人准备的音频 AI

下载
  1. 全局设置随机种子: 在脚本的开头,使用keras.utils.set_random_seed()来设置所有涉及Keras和TensorFlow操作的随机种子。

    import numpy as np
    import tensorflow as tf
    import os
    import autokeras as ak
    import keras # 导入keras
    
    # 设置随机种子以确保复现性
    random_seed = 42 # 选择一个你喜欢的整数
    keras.utils.set_random_seed(random_seed)
    tf.config.experimental.set_memory_growth(tf.config.list_physical_devices('GPU')[0], True) # 如果使用GPU,可选
  2. 初始化Autokeras分类器时指定种子和覆盖模式: 在初始化StructuredDataClassifier时,除了设置seed参数外,还建议设置overwrite=True。overwrite=True可以确保每次运行时都会从头开始进行超参数搜索,而不会加载之前运行的结果,从而避免潜在的干扰。

    # 初始化结构化数据分类器
    # overwrite=True 确保每次运行都重新开始搜索,不加载之前的结果
    # seed 参数进一步确保 autokeras 内部的随机性可控
    clf = ak.StructuredDataClassifier(overwrite=True, max_trials=10, seed=random_seed)
  3. 增加max_trials以稳定结果(可选但推荐):max_trials参数决定了Autokeras尝试的不同模型架构和超参数组合的数量。当max_trials较小(例如默认的10)时,超参数搜索可能不够充分,导致结果对随机种子非常敏感。增加max_trials(例如设置为50或100)可以使搜索过程更全面,从而提高找到稳定且高性能模型的概率,减少不同随机种子带来的结果波动。

优化标签编码实践

尽管Autokeras能够内部处理One-Hot编码,但为了代码的清晰性和与大多数分类API的约定保持一致,建议在将数据传递给StructuredDataClassifier之前,将One-Hot编码的标签转换为整数标签。这简化了tf.data.Dataset.from_generator的output_signature定义,并使标签的含义更加直观。

以下是转换为整数标签的示例代码片段:

import numpy as np
import tensorflow as tf
import os
import autokeras as ak
import keras

# 设置随机种子
random_seed = 42
keras.utils.set_random_seed(random_seed)

N_FEATURES = 8
N_CLASSES = 3
BATCH_SIZE = 100

def get_data_generator(folder_path, batch_size, n_features):
    """
    获取一个数据生成器,从指定文件夹的.npy文件中分批返回数据。
    特征的形状为 (batch_size, n_features)。
    标签的形状为 (batch_size,),为整数形式。
    """
    def data_generator():
        files = os.listdir(folder_path)
        npy_files = [f for f in files if f.endswith('.npy')]

        for npy_file in npy_files:
            data = np.load(os.path.join(folder_path, npy_file))
            x = data[:, :n_features]
            y_ohe = data[:, n_features:]
            y_int = np.argmax(y_ohe, axis=1) # 将One-Hot编码转换为整数标签

            for i in range(0, len(x), batch_size):
                yield x[i:i+batch_size], y_int[i:i+batch_size]

    return data_generator

train_data_folder = '/home/my_user_name/original_data/train_data_npy'
validation_data_folder = '/home/my_user_name/original_data/valid_data_npy'

# 创建训练数据集,标签为1D整数
train_dataset = tf.data.Dataset.from_generator(
    get_data_generator(train_data_folder, BATCH_SIZE, N_FEATURES),
    output_signature=(
        tf.TensorSpec(shape=(None, N_FEATURES), dtype=tf.float32),
        tf.TensorSpec(shape=(None,), dtype=tf.int32) # 标签现在是1D整数
    )
)

# 创建验证数据集,标签为1D整数
validation_dataset = tf.data.Dataset.from_generator(
    get_data_generator(validation_data_folder, BATCH_SIZE, N_FEATURES),
    output_signature=(
        tf.TensorSpec(shape=(None, N_FEATURES), dtype=tf.float32),
        tf.TensorSpec(shape=(None,), dtype=tf.int32) # 标签现在是1D整数
    )
)

# 初始化分类器,并设置随机种子和覆盖模式
clf = ak.StructuredDataClassifier(overwrite=True, max_trials=10, seed=random_seed)

# 训练分类器
clf.fit(train_dataset, epochs=100)

# 评估模型
print("Model evaluation results:", clf.evaluate(validation_dataset))

# 导出并保存模型 (可选)
model = clf.export_model()
model.save("heca_v2_model_reproducible", save_format='tf')

总结

当Autokeras模型在不同运行中表现出显著性能差异时,即使标签编码方式看似合理,其根本原因也往往是随机种子未被妥善管理。Autokeras的StructuredDataClassifier能够内部处理整数标签并进行One-Hot转换,因此直接提供One-Hot编码的标签通常不是性能低下的直接原因。通过在脚本开头全局设置随机种子、在分类器初始化时指定种子并设置overwrite=True,可以有效地提高模型训练的复现性。此外,适当地增加max_trials参数,以及始终将One-Hot编码的标签转换为整数形式再输入模型,是构建稳定、可信赖的AutoML工作流的最佳实践。

相关文章

数码产品性能查询
数码产品性能查询

该软件包括了市面上所有手机CPU,手机跑分情况,电脑CPU,电脑产品信息等等,方便需要大家查阅数码产品最新情况,了解产品特性,能够进行对比选择最具性价比的商品。

下载

本站声明:本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn

相关专题

更多
页面置换算法
页面置换算法

页面置换算法是操作系统中用来决定在内存中哪些页面应该被换出以便为新的页面提供空间的算法。本专题为大家提供页面置换算法的相关文章,大家可以免费体验。

383

2023.08.14

Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习
Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习

PyTorch 是一种用于构建深度学习模型的功能完备框架,是一种通常用于图像识别和语言处理等应用程序的机器学习。 使用Python 编写,因此对于大多数机器学习开发者而言,学习和使用起来相对简单。 PyTorch 的独特之处在于,它完全支持GPU,并且使用反向模式自动微分技术,因此可以动态修改计算图形。

5

2025.12.22

PHP 命令行脚本与自动化任务开发
PHP 命令行脚本与自动化任务开发

本专题系统讲解 PHP 在命令行环境(CLI)下的开发与应用,内容涵盖 PHP CLI 基础、参数解析、文件与目录操作、日志输出、异常处理,以及与 Linux 定时任务(Cron)的结合使用。通过实战示例,帮助开发者掌握使用 PHP 构建 自动化脚本、批处理工具与后台任务程序 的能力。

21

2025.12.13

ip地址修改教程大全
ip地址修改教程大全

本专题整合了ip地址修改教程大全,阅读下面的文章自行寻找合适的解决教程。

35

2025.12.26

压缩文件加密教程汇总
压缩文件加密教程汇总

本专题整合了压缩文件加密教程,阅读专题下面的文章了解更多详细教程。

18

2025.12.26

wifi无ip分配
wifi无ip分配

本专题整合了wifi无ip分配相关教程,阅读专题下面的文章了解更多详细教程。

46

2025.12.26

漫蛙漫画入口网址
漫蛙漫画入口网址

本专题整合了漫蛙入口网址大全,阅读下面的文章领取更多入口。

94

2025.12.26

b站看视频入口合集
b站看视频入口合集

本专题整合了b站哔哩哔哩相关入口合集,阅读下面的文章查看更多入口。

289

2025.12.26

俄罗斯搜索引擎yandex入口汇总
俄罗斯搜索引擎yandex入口汇总

本专题整合了俄罗斯搜索引擎yandex相关入口合集,阅读下面的文章查看更多入口。

372

2025.12.26

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
Go 教程
Go 教程

共32课时 | 3万人学习

Go语言实战之 GraphQL
Go语言实战之 GraphQL

共10课时 | 0.8万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号