Vision Transformer多标签分类:损失函数与评估策略深度解析

聖光之護
发布: 2025-10-17 11:05:31
原创
247人浏览过

vision transformer多标签分类:损失函数与评估策略深度解析

本文旨在详细阐述如何将Vision Transformer(ViT)从单标签多分类任务转换为多标签分类任务,并重点介绍损失函数的选择与评估策略的调整。我们将探讨为何`CrossEntropyLoss`不适用于多标签场景,并深入讲解`BCEWithLogitsLoss`的使用方法,包括标签格式要求。此外,文章还将介绍多标签分类任务中常用的评估指标,如精确率、召回率、F1分数和mAP,并提供代码示例,确保读者能够顺利实现ViT在多标签环境下的训练与评估。

从单标签到多标签:核心概念转变

深度学习的图像分类任务中,单标签多分类(Single-label Multi-class Classification)是指每张图片只属于一个类别,模型需要从多个互斥的类别中预测出唯一正确的那个。而多标签分类(Multi-label Classification)则允许每张图片同时属于一个或多个类别,模型需要为每个类别独立地判断其是否存在于图片中。

这种任务性质的转变,要求我们对模型的输出层、损失函数以及评估策略进行相应的调整。对于Vision Transformer(ViT)而言,其特征提取部分通常保持不变,但最终的分类头和训练流程需要进行适配。

损失函数的选择与实现

在单标签多分类任务中,我们通常使用torch.nn.CrossEntropyLoss作为损失函数。它内部包含了Softmax激活函数和负对数似然损失,期望模型的输出是每个类别的Logits,并且这些Logits经过Softmax后会转化为概率分布,所有类别的概率和为1。

然而,在多标签分类任务中,由于图片可能同时属于多个类别,各个类别之间不再是互斥关系。因此,CrossEntropyLoss不再适用,因为它强制了类别之间的互斥性。

推荐的损失函数:torch.nn.BCEWithLogitsLoss

对于多标签分类任务,最常用且推荐的损失函数是torch.nn.BCEWithLogitsLoss。这个损失函数结合了Sigmoid激活函数和二元交叉熵损失(Binary Cross Entropy Loss)。

其主要优点包括:

  1. 独立处理每个类别: BCEWithLogitsLoss会对模型输出的每个Logit独立地计算二元交叉熵,这与多标签任务中各类别独立存在的特性相符。
  2. 数值稳定性: 它直接作用于模型的原始Logits输出,内部处理了Sigmoid激活,避免了先手动计算Sigmoid再计算交叉熵可能导致的数值溢出或下溢问题。

使用BCEWithLogitsLoss的注意事项:

  1. 模型输出: 模型的最终输出层应该是一个全连接层,输出维度等于类别的总数,且不应在其后接Softmax激活函数。例如,如果你的模型有7个类别,最终输出应为形状(batch_size, 7)的Logits张量。
  2. 标签格式: 标签(target)必须是与模型输出Logits形状相同的浮点型(torch.float)张量。它通常是一个“多热编码”(multi-hot encoding)向量,其中1表示该类别存在,0表示该类别不存在。例如,[0, 1, 1, 0, 0, 1, 0]表示第二个、第三个和第六个类别存在。

代码示例:替换损失函数

假设我们有一个ViT模型,其输出为pred(Logits),标签为labels(多热编码)。

import torch
import torch.nn as nn

# 假设模型输出的Logits,形状为 (batch_size, num_classes)
# 这里以 batch_size = 2, num_classes = 7 为例
logits = torch.randn(2, 7) # 模拟模型输出的原始Logits

# 假设对应的多标签,形状也为 (batch_size, num_classes)
# 注意:标签必须是浮点型 (torch.float)
labels = torch.tensor([
    [0, 1, 1, 0, 0, 1, 0], # 第一个样本的标签
    [1, 0, 1, 1, 0, 0, 0]  # 第二个样本的标签
]).float()

# 实例化 BCEWithLogitsLoss
loss_function = nn.BCEWithLogitsLoss()

# 计算损失
loss = loss_function(logits, labels)

print(f"Logits:\n{logits}")
print(f"Labels:\n{labels}")
print(f"Calculated Loss: {loss.item()}")

# 原始训练循环中的应用
# pred = model(images.to(device))
# loss = loss_function(pred, labels.to(device))
# loss.backward()
# optimizer.step()
登录后复制

多标签分类的评估策略

在单标签分类中,准确率(Accuracy)是最常用的评估指标。然而,在多标签分类中,仅仅计算准确率是不足够的,甚至可能产生误导。例如,如果一个模型总是预测所有类别都不存在,而实际只有少数类别存在,那么它的准确率可能很高(因为它正确预测了大量不存在的类别),但它对存在类别的识别能力却很差。

因此,我们需要采用更全面的指标来评估多标签分类模型的性能。

1. 从Logits到预测结果

百度GBI
百度GBI

百度GBI-你的大模型商业分析助手

百度GBI104
查看详情 百度GBI

在计算评估指标之前,我们需要将模型的Logits输出转换为具体的类别预测。这通常通过对Logits应用Sigmoid函数,然后设定一个阈值(例如0.5)来完成。

# 假设 logits 是模型输出的Logits
# 例如:logits = torch.randn(batch_size, num_classes)

# 1. 应用Sigmoid函数将Logits转换为概率
probabilities = torch.sigmoid(logits)

# 2. 设定阈值,将概率转换为二元预测 (0或1)
threshold = 0.5
predictions = (probabilities > threshold).float()

print(f"Probabilities:\n{probabilities}")
print(f"Predictions (threshold={threshold}):\n{predictions}")
登录后复制

2. 常用评估指标

以下是多标签分类中常用的评估指标:

  • 精确率(Precision)、召回率(Recall)、F1分数(F1-score):

    • 精确率: 预测为正例的样本中,有多少是真正的正例。
    • 召回率: 实际为正例的样本中,有多少被模型预测为正例。
    • F1分数: 精确率和召回率的调和平均值,综合衡量模型的性能。
    • 这些指标可以针对每个类别独立计算(Per-class),也可以通过微平均(Micro-average)或宏平均(Macro-average)来汇总所有类别的结果。
      • Micro-average: 汇总所有类别的TP、FP、FN后再计算总体的Precision、Recall、F1。它更侧重于样本级别的性能,受样本数量较多的类别影响较大。
      • Macro-average: 先计算每个类别的Precision、Recall、F1,然后取这些值的平均。它给予每个类别相同的权重,不受类别样本数量不平衡的影响。
  • 平均精确率(Average Precision, AP)与平均精确率均值(mean Average Precision, mAP):

    • AP: 衡量单个类别在不同召回率下的精确率表现,通常通过计算PR曲线下面积获得。AP值越高,说明模型在该类别上的性能越好。
    • mAP: 对所有类别的AP值取平均,是衡量多标签分类模型整体性能的一个非常重要的指标,尤其在目标检测等领域广泛使用。
  • Jaccard Index (IoU) / Jaccard Similarity Score:

    • 衡量预测集合与真实标签集合的相似度,计算公式为交集大小除以并集大小。对于多标签分类,可以计算每个样本的预测标签集合与真实标签集合的Jaccard相似度,然后取平均。
  • Hamming Loss:

    • 衡量预测结果与真实标签不一致的标签比例。Hamming Loss越低越好。

3. 使用torchmetrics或scikit-learn进行评估

在PyTorch生态中,torchmetrics库提供了丰富的多标签评估指标。scikit-learn也是一个非常强大的工具,可以在CPU上方便地进行评估。

torchmetrics示例 (推荐用于PyTorch训练循环中):

import torch
from torchmetrics.classification import MultilabelF1Score, MultilabelAveragePrecision

# 假设真实标签和预测概率
# num_classes = 7
num_labels = 7
num_samples = 10
target_labels = torch.randint(0, 2, (num_samples, num_labels)).float() # 真实标签 (0或1)
predicted_probs = torch.rand(num_samples, num_labels) # 模型输出的概率 (经过Sigmoid)

# 或者直接使用Logits,让metrics内部处理Sigmoid
predicted_logits = torch.randn(num_samples, num_labels)


# 实例化F1分数,可以指定 average 方式 (e.g., 'micro', 'macro', 'weighted', 'none')
# MultilabelF1Score 期望输入是 (preds, target)
# preds: 概率 (float) 或 原始logits (float)
# target: 真实标签 (int 或 float, 0/1)
f1_score_micro = MultilabelF1Score(num_labels=num_labels, average='micro', validate_args=False)
f1_score_macro = MultilabelF1Score(num_labels=num_labels, average='macro', validate_args=False)

# 计算F1分数
# 注意:MultilabelF1Score 可以直接接收概率或logits,但通常建议给概率
f1_micro_val = f1_score_micro(predicted_probs, target_labels.long()) # target_labels需要是long类型对于F1Score
f1_macro_val = f1_score_macro(predicted_probs, target_labels.long())


print(f"Micro F1 Score: {f1_micro_val.item()}")
print(f"Macro F1 Score: {f1_macro_val.item()}")

# 实例化mAP
# MultilabelAveragePrecision 期望输入是 (preds, target)
# preds: 概率 (float)
# target: 真实标签 (int 或 float, 0/1)
map_metric = MultilabelAveragePrecision(num_labels=num_labels, validate_args=False)

# 计算mAP
map_val = map_metric(predicted_probs, target_labels.long()) # target_labels需要是long类型对于mAP

print(f"mAP: {map_val.item()}")

# 如果输入是logits,可以这样处理 (MultilabelF1Score 和 MultilabelAveragePrecision 默认不带sigmoid,需要手动处理或确保其内部处理了)
# 对于MultilabelF1Score和MultilabelAveragePrecision,当输入是概率时,通常需要手动将target转换为long
# 如果输入是logits,则需要确保metrics内部会执行sigmoid
# 更好的做法是,统一将模型输出转换为概率再传入metrics
probs_from_logits = torch.sigmoid(predicted_logits)
f1_micro_val_logits = f1_score_micro(probs_from_logits, target_labels.long())
map_val_logits = map_metric(probs_from_logits, target_labels.long())
print(f"Micro F1 Score (from logits): {f1_micro_val_logits.item()}")
print(f"mAP (from logits): {map_val_logits.item()}")
登录后复制

总结与注意事项

将ViT从单标签多分类转换为多标签分类,关键在于以下几点:

  1. 模型输出层: 确保模型的最终全连接层输出与类别数量相等的Logits,并且不带Softmax激活。
  2. 损失函数: 使用torch.nn.BCEWithLogitsLoss作为损失函数,它能独立处理每个类别的预测。
  3. 标签格式: 真实标签应为多热编码的浮点型张量,形状与模型输出的Logits相同。
  4. 评估指标: 采用适合多标签任务的评估指标,如Micro/Macro F1分数、mAP、Jaccard Index等,并结合torchmetrics或scikit-learn等库进行高效计算。
  5. 阈值选择: 在将概率转换为二元预测时,阈值的选择对最终的精确率和召回率有显著影响,可能需要通过验证集进行调优。
  6. 类别不平衡: 在多标签任务中,类别不平衡问题可能更复杂(例如,某些标签总是同时出现,某些标签非常稀有)。可以考虑使用加权BCE损失、Focal Loss或采样策略来缓解。

通过以上调整,您的Vision Transformer模型将能够有效地处理多标签图像分类任务。

以上就是Vision Transformer多标签分类:损失函数与评估策略深度解析的详细内容,更多请关注php中文网其它相关文章!

最佳 Windows 性能的顶级免费优化软件
最佳 Windows 性能的顶级免费优化软件

每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。

下载
来源:php中文网
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn
最新问题
开源免费商场系统广告
热门教程
更多>
最新下载
更多>
网站特效
网站源码
网站素材
前端模板
关于我们 免责申明 意见反馈 讲师合作 广告合作 最新更新 English
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送
PHP中文网APP
随时随地碎片化学习
PHP中文网抖音号
发现有趣的

Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号