
本文旨在详细阐述如何将Vision Transformer(ViT)从单标签多分类任务转换为多标签分类任务,并重点介绍损失函数的选择与评估策略的调整。我们将探讨为何`CrossEntropyLoss`不适用于多标签场景,并深入讲解`BCEWithLogitsLoss`的使用方法,包括标签格式要求。此外,文章还将介绍多标签分类任务中常用的评估指标,如精确率、召回率、F1分数和mAP,并提供代码示例,确保读者能够顺利实现ViT在多标签环境下的训练与评估。
在深度学习的图像分类任务中,单标签多分类(Single-label Multi-class Classification)是指每张图片只属于一个类别,模型需要从多个互斥的类别中预测出唯一正确的那个。而多标签分类(Multi-label Classification)则允许每张图片同时属于一个或多个类别,模型需要为每个类别独立地判断其是否存在于图片中。
这种任务性质的转变,要求我们对模型的输出层、损失函数以及评估策略进行相应的调整。对于Vision Transformer(ViT)而言,其特征提取部分通常保持不变,但最终的分类头和训练流程需要进行适配。
在单标签多分类任务中,我们通常使用torch.nn.CrossEntropyLoss作为损失函数。它内部包含了Softmax激活函数和负对数似然损失,期望模型的输出是每个类别的Logits,并且这些Logits经过Softmax后会转化为概率分布,所有类别的概率和为1。
然而,在多标签分类任务中,由于图片可能同时属于多个类别,各个类别之间不再是互斥关系。因此,CrossEntropyLoss不再适用,因为它强制了类别之间的互斥性。
推荐的损失函数:torch.nn.BCEWithLogitsLoss
对于多标签分类任务,最常用且推荐的损失函数是torch.nn.BCEWithLogitsLoss。这个损失函数结合了Sigmoid激活函数和二元交叉熵损失(Binary Cross Entropy Loss)。
其主要优点包括:
使用BCEWithLogitsLoss的注意事项:
代码示例:替换损失函数
假设我们有一个ViT模型,其输出为pred(Logits),标签为labels(多热编码)。
import torch
import torch.nn as nn
# 假设模型输出的Logits,形状为 (batch_size, num_classes)
# 这里以 batch_size = 2, num_classes = 7 为例
logits = torch.randn(2, 7) # 模拟模型输出的原始Logits
# 假设对应的多标签,形状也为 (batch_size, num_classes)
# 注意:标签必须是浮点型 (torch.float)
labels = torch.tensor([
[0, 1, 1, 0, 0, 1, 0], # 第一个样本的标签
[1, 0, 1, 1, 0, 0, 0] # 第二个样本的标签
]).float()
# 实例化 BCEWithLogitsLoss
loss_function = nn.BCEWithLogitsLoss()
# 计算损失
loss = loss_function(logits, labels)
print(f"Logits:\n{logits}")
print(f"Labels:\n{labels}")
print(f"Calculated Loss: {loss.item()}")
# 原始训练循环中的应用
# pred = model(images.to(device))
# loss = loss_function(pred, labels.to(device))
# loss.backward()
# optimizer.step()在单标签分类中,准确率(Accuracy)是最常用的评估指标。然而,在多标签分类中,仅仅计算准确率是不足够的,甚至可能产生误导。例如,如果一个模型总是预测所有类别都不存在,而实际只有少数类别存在,那么它的准确率可能很高(因为它正确预测了大量不存在的类别),但它对存在类别的识别能力却很差。
因此,我们需要采用更全面的指标来评估多标签分类模型的性能。
1. 从Logits到预测结果
在计算评估指标之前,我们需要将模型的Logits输出转换为具体的类别预测。这通常通过对Logits应用Sigmoid函数,然后设定一个阈值(例如0.5)来完成。
# 假设 logits 是模型输出的Logits
# 例如:logits = torch.randn(batch_size, num_classes)
# 1. 应用Sigmoid函数将Logits转换为概率
probabilities = torch.sigmoid(logits)
# 2. 设定阈值,将概率转换为二元预测 (0或1)
threshold = 0.5
predictions = (probabilities > threshold).float()
print(f"Probabilities:\n{probabilities}")
print(f"Predictions (threshold={threshold}):\n{predictions}")2. 常用评估指标
以下是多标签分类中常用的评估指标:
精确率(Precision)、召回率(Recall)、F1分数(F1-score):
平均精确率(Average Precision, AP)与平均精确率均值(mean Average Precision, mAP):
Jaccard Index (IoU) / Jaccard Similarity Score:
Hamming Loss:
3. 使用torchmetrics或scikit-learn进行评估
在PyTorch生态中,torchmetrics库提供了丰富的多标签评估指标。scikit-learn也是一个非常强大的工具,可以在CPU上方便地进行评估。
torchmetrics示例 (推荐用于PyTorch训练循环中):
import torch
from torchmetrics.classification import MultilabelF1Score, MultilabelAveragePrecision
# 假设真实标签和预测概率
# num_classes = 7
num_labels = 7
num_samples = 10
target_labels = torch.randint(0, 2, (num_samples, num_labels)).float() # 真实标签 (0或1)
predicted_probs = torch.rand(num_samples, num_labels) # 模型输出的概率 (经过Sigmoid)
# 或者直接使用Logits,让metrics内部处理Sigmoid
predicted_logits = torch.randn(num_samples, num_labels)
# 实例化F1分数,可以指定 average 方式 (e.g., 'micro', 'macro', 'weighted', 'none')
# MultilabelF1Score 期望输入是 (preds, target)
# preds: 概率 (float) 或 原始logits (float)
# target: 真实标签 (int 或 float, 0/1)
f1_score_micro = MultilabelF1Score(num_labels=num_labels, average='micro', validate_args=False)
f1_score_macro = MultilabelF1Score(num_labels=num_labels, average='macro', validate_args=False)
# 计算F1分数
# 注意:MultilabelF1Score 可以直接接收概率或logits,但通常建议给概率
f1_micro_val = f1_score_micro(predicted_probs, target_labels.long()) # target_labels需要是long类型对于F1Score
f1_macro_val = f1_score_macro(predicted_probs, target_labels.long())
print(f"Micro F1 Score: {f1_micro_val.item()}")
print(f"Macro F1 Score: {f1_macro_val.item()}")
# 实例化mAP
# MultilabelAveragePrecision 期望输入是 (preds, target)
# preds: 概率 (float)
# target: 真实标签 (int 或 float, 0/1)
map_metric = MultilabelAveragePrecision(num_labels=num_labels, validate_args=False)
# 计算mAP
map_val = map_metric(predicted_probs, target_labels.long()) # target_labels需要是long类型对于mAP
print(f"mAP: {map_val.item()}")
# 如果输入是logits,可以这样处理 (MultilabelF1Score 和 MultilabelAveragePrecision 默认不带sigmoid,需要手动处理或确保其内部处理了)
# 对于MultilabelF1Score和MultilabelAveragePrecision,当输入是概率时,通常需要手动将target转换为long
# 如果输入是logits,则需要确保metrics内部会执行sigmoid
# 更好的做法是,统一将模型输出转换为概率再传入metrics
probs_from_logits = torch.sigmoid(predicted_logits)
f1_micro_val_logits = f1_score_micro(probs_from_logits, target_labels.long())
map_val_logits = map_metric(probs_from_logits, target_labels.long())
print(f"Micro F1 Score (from logits): {f1_micro_val_logits.item()}")
print(f"mAP (from logits): {map_val_logits.item()}")将ViT从单标签多分类转换为多标签分类,关键在于以下几点:
通过以上调整,您的Vision Transformer模型将能够有效地处理多标签图像分类任务。
以上就是Vision Transformer多标签分类:损失函数与评估策略深度解析的详细内容,更多请关注php中文网其它相关文章!
每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。
Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号