理解 Transformers 中的交叉熵损失及 Masked Label 问题

花韻仙語
发布: 2025-10-01 10:51:02
原创
685人浏览过

理解 transformers 中的交叉熵损失及 masked label 问题

本文旨在深入解析 Hugging Face Transformers 库中,使用 GPT-2 等 Decoder-Only 模型计算交叉熵损失时,如何正确使用 masked label,并解释了常见的困惑。通过具体示例和代码,详细阐述了 target_ids 的构建方法,以及如何结合 ignore_index 来控制损失计算的范围,从而避免不必要的计算偏差,并提供了手动计算损失的替代方案。

在使用 Hugging Face Transformers 库进行自然语言处理任务时,尤其是使用 GPT-2 等 Decoder-Only 模型时,理解交叉熵损失的计算方式和 masked label 的作用至关重要。本文将深入探讨 target_ids 的正确构建方法,以及如何利用 ignore_index 来精确控制损失计算的范围,从而避免常见的错误和困惑。

Decoder-Only 模型、输入和目标

在 Hugging Face Transformers 库中,Decoder-Only 模型(如 GPT-2)主要依赖 input_ids、label_ids 和 attention_mask 进行训练。其中,input_ids 代表输入序列的 token IDs,label_ids 代表目标序列的 token IDs,而 attention_mask 用于指示哪些 token 应该被模型关注。

假设我们有一个输入 "The answer is:",我们希望模型学习回答 "42"。将这个句子转化为 token IDs,假设 "The answer is: 42" 对应的 IDs 是 [464, 3280, 318, 25, 5433](其中 ":" 是 25," 42" 是 5433)。

为了让模型学习预测 "42",我们需要设置 label_ids 为 [-100, -100, -100, -100, 5433]。这样,模型就不会学习到 "The answer" 后面应该跟着 "is:",因为这些位置的损失被忽略了。

注意: Decoder-Only 模型要求输入和输出具有相同的形状。这与 Encoder-Decoder 模型不同,后者可以接受 "The answer is:" 作为输入,而 "42" 作为输出。

-100 是 torch.nn.CrossEntropyLoss 的默认 ignore_index。使用 "忽略" 比 "mask" 更准确,因为 "mask" 暗示模型看不到这些输入,或者原始输入被替换为特殊的 "<masked>" token。

理解问题的根源

原始问题中,代码 target_ids[:, :-seq_len] = -100 试图将 target_ids 中除了最后 seq_len 个元素之外的所有元素设置为 -100。然而,由于 target_ids 的长度为 seq_len,所以实际上没有任何元素被修改,导致损失计算结果不变。

迭代数据集时的正确方法

在使用滑动窗口迭代数据集时,masked label 的应用需要在不同的迭代步骤中进行调整。以下是一个示例:

AI建筑知识问答
AI建筑知识问答

用人工智能ChatGPT帮你解答所有建筑问题

AI建筑知识问答 22
查看详情 AI建筑知识问答

第一次迭代:

max_length = 1024
stride = 512

end_loc = 1024
input_ids = tokens[0 : 1024]
target_ids = input_ids.clone()
target_ids[:-1024] = -100  # 实际上没有修改任何元素

assert torch.equal(target_ids, input_ids)

trg_len = 1024
prev_end_loc = 1024
登录后复制

在第一次迭代中,由于 target_ids[:-1024] 实际上等于 target_ids[:0],因此 target_ids 没有被修改,损失是基于所有 1024 个 token 计算的。

第二次及后续迭代:

begin_loc = 512
end_loc = 1536
trg_len = 1536 - 1024  # 512

input_ids = tokens[512 : 1536]  # 注意:tokens 512-1024 已经被模型看到过
target_ids = tokens[512 : 1536].clone()
target_ids[:-512] = -100  # 将已经见过的 token 对应的 label 设置为 -100
登录后复制

从第二次迭代开始,target_ids 的前 512 个元素(对应于模型已经见过的 token)被设置为 -100,损失仅基于后 512 个 token 计算。

手动计算损失

如果需要更精细地控制损失计算过程,可以直接从模型获取 logits,然后手动计算交叉熵损失。

from torch.nn import CrossEntropyLoss

outputs = model(encodings.input_ids, labels=None)

logits = outputs.logits
labels = target_ids.to(logits.device)

# 调整 logits 和 labels 的形状,使其匹配
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

# 计算损失
loss_fct = CrossEntropyLoss(reduction='mean')
loss = loss_fct(shift_logits.view(-1, model.config.vocab_size), shift_labels.view(-1))

print(loss.item())
登录后复制

这段代码首先从模型获取 logits,然后将 logits 和 labels 的形状进行调整,使其能够匹配。最后,使用 CrossEntropyLoss 计算损失。

总结:

理解 Decoder-Only 模型中 target_ids 的构建方式,以及如何利用 ignore_index 来控制损失计算的范围,是使用 Hugging Face Transformers 库进行自然语言处理任务的关键。通过正确设置 target_ids,可以避免不必要的计算偏差,并提高模型的训练效果。

以上就是理解 Transformers 中的交叉熵损失及 Masked Label 问题的详细内容,更多请关注php中文网其它相关文章!

最佳 Windows 性能的顶级免费优化软件
最佳 Windows 性能的顶级免费优化软件

每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。

下载
来源:php中文网
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn
最新问题
开源免费商场系统广告
热门教程
更多>
最新下载
更多>
网站特效
网站源码
网站素材
前端模板
关于我们 免责申明 意见反馈 讲师合作 广告合作 最新更新 English
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送
PHP中文网APP
随时随地碎片化学习

Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号