0

0

PyTorch DataLoader 自定义 Sampler 迭代问题解决

DDD

DDD

发布时间:2025-10-18 08:01:20

|

727人浏览过

|

来源于php中文网

原创

pytorch dataloader 自定义 sampler 迭代问题解决

本文针对 PyTorch 中使用自定义 Sampler 时,DataLoader 只能迭代一个 epoch 的问题进行了分析和解决。通过修改 Sampler 的 `__next__` 方法,在抛出 `StopIteration` 异常时重置索引,使得 DataLoader 可以在多个 epoch 中正常迭代。文章提供了一个完整的代码示例,演示了如何实现一个可以根据不同 batch size 采样数据的自定义 Sampler,并确保其在训练循环中正常工作。

在使用 PyTorch 进行深度学习模型训练时,DataLoader 是一个非常重要的工具,它负责数据的加载和预处理。DataLoader 可以与 Sampler 结合使用,以控制数据的采样方式。然而,当使用自定义的 Sampler 时,可能会遇到 DataLoader 只能迭代一个 epoch 的问题。这通常是由于 Sampler 在一个 epoch 结束后没有正确地重置其内部状态导致的。

问题分析

当 DataLoader 迭代 Sampler 时,它会不断调用 Sampler 的 __next__ 方法来获取下一个 batch 的索引。当 Sampler 完成一次完整的数据集遍历后,它应该抛出一个 StopIteration 异常来通知 DataLoader 停止迭代。然而,如果 Sampler 在抛出 StopIteration 异常后没有重置其内部索引,那么在下一个 epoch 开始时,Sampler 仍然处于完成状态,导致 DataLoader 无法继续迭代。

解决方案

解决这个问题的方法是在 Sampler 的 __next__ 方法中,当检测到数据集已经遍历完毕并准备抛出 StopIteration 异常时,同时重置 Sampler 的内部索引。

下面是一个示例,展示了如何修改一个自定义的 Sampler 来解决这个问题。假设我们有一个 VariableBatchSampler,它可以根据预定义的 batch_sizes 列表来生成不同大小的 batch。

import torch
import numpy as np
from torch.utils.data import Sampler
from torch.utils.data import DataLoader, TensorDataset



class VariableBatchSampler(Sampler):
    def __init__(self, dataset_len: int, batch_sizes: list):
        self.dataset_len = dataset_len
        self.batch_sizes = batch_sizes
        self.batch_idx = 0
        self.start_idx = 0
        self.end_idx = self.batch_sizes[self.batch_idx]

    def __iter__(self):
        return self

    def __next__(self):
        if self.start_idx >= self.dataset_len:
            self.batch_idx = 0
            self.start_idx = 0
            self.end_idx = self.batch_sizes[self.batch_idx]
            raise StopIteration

        batch_indices = list(range(self.start_idx, self.end_idx))
        self.start_idx = self.end_idx
        self.batch_idx += 1

        try:
            self.end_idx += self.batch_sizes[self.batch_idx]
        except IndexError:
            self.end_idx = self.dataset_len

        return batch_indices

在这个 VariableBatchSampler 中,我们在 __next__ 方法中添加了以下代码:

比话降AI
比话降AI

清除AIGC痕迹,AI率降低至15%

下载
if self.start_idx >= self.dataset_len:
    self.batch_idx = 0
    self.start_idx = 0
    self.end_idx = self.batch_sizes[self.batch_idx]
    raise StopIteration

这段代码在 self.start_idx 大于或等于 self.dataset_len 时执行,这意味着我们已经遍历了整个数据集。此时,我们将 self.batch_idx、self.start_idx 和 self.end_idx 重置为初始值,以便在下一个 epoch 中重新开始迭代。

完整示例

下面是一个完整的示例,展示了如何使用修改后的 VariableBatchSampler 和 DataLoader 进行多 epoch 训练。

import torch
import numpy as np
from torch.utils.data import Sampler
from torch.utils.data import DataLoader, TensorDataset



class VariableBatchSampler(Sampler):
    def __init__(self, dataset_len: int, batch_sizes: list):
        self.dataset_len = dataset_len
        self.batch_sizes = batch_sizes
        self.batch_idx = 0
        self.start_idx = 0
        self.end_idx = self.batch_sizes[self.batch_idx]

    def __iter__(self):
        return self

    def __next__(self):
        if self.start_idx >= self.dataset_len:
            self.batch_idx = 0
            self.start_idx = 0
            self.end_idx = self.batch_sizes[self.batch_idx]
            raise StopIteration

        batch_indices = list(range(self.start_idx, self.end_idx))
        self.start_idx = self.end_idx
        self.batch_idx += 1

        try:
            self.end_idx += self.batch_sizes[self.batch_idx]
        except IndexError:
            self.end_idx = self.dataset_len

        return batch_indices


x_train = torch.randn(23)
y_train = torch.randint(0, 2, (23,))

batch_sizes = [4, 10, 7, 2]
train_dataset = TensorDataset(x_train, y_train)
sampler = VariableBatchSampler(dataset_len=len(x_train), batch_sizes=batch_sizes)
dataloader_train = DataLoader(train_dataset, sampler=sampler)

max_epoch = 4
for epoch in np.arange(1, max_epoch):
    print("Epoch: ", epoch)
    for x_batch, y_batch in dataloader_train:
         print(x_batch.shape)

这段代码会输出每个 epoch 中每个 batch 的形状,证明 DataLoader 可以在多个 epoch 中正常迭代。

总结

当使用自定义的 Sampler 时,确保在 __next__ 方法中正确地重置内部索引,以便 DataLoader 可以在多个 epoch 中正常迭代。 否则,DataLoader 在第一个epoch后会停止工作。 通过本文提供的示例,您可以更好地理解如何实现一个自定义的 Sampler,并解决 DataLoader 迭代问题。

相关专题

更多
pytorch是干嘛的
pytorch是干嘛的

pytorch是一个基于python的深度学习框架,提供以下主要功能:动态图计算,提供灵活性。强大的张量操作,实现高效处理。自动微分,简化梯度计算。预构建的神经网络模块,简化模型构建。各种优化器,用于性能优化。想了解更多pytorch的相关内容,可以阅读本专题下面的文章。

428

2024.05.29

Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习
Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习

PyTorch 是一种用于构建深度学习模型的功能完备框架,是一种通常用于图像识别和语言处理等应用程序的机器学习。 使用Python 编写,因此对于大多数机器学习开发者而言,学习和使用起来相对简单。 PyTorch 的独特之处在于,它完全支持GPU,并且使用反向模式自动微分技术,因此可以动态修改计算图形。

12

2025.12.22

php源码安装教程大全
php源码安装教程大全

本专题整合了php源码安装教程,阅读专题下面的文章了解更多详细内容。

127

2025.12.31

php网站源码教程大全
php网站源码教程大全

本专题整合了php网站源码相关教程,阅读专题下面的文章了解更多详细内容。

75

2025.12.31

视频文件格式
视频文件格式

本专题整合了视频文件格式相关内容,阅读专题下面的文章了解更多详细内容。

81

2025.12.31

不受国内限制的浏览器大全
不受国内限制的浏览器大全

想找真正自由、无限制的上网体验?本合集精选2025年最开放、隐私强、访问无阻的浏览器App,涵盖Tor、Brave、Via、X浏览器、Mullvad等高自由度工具。支持自定义搜索引擎、广告拦截、隐身模式及全球网站无障碍访问,部分更具备防追踪、去谷歌化、双内核切换等高级功能。无论日常浏览、隐私保护还是突破地域限制,总有一款适合你!

60

2025.12.31

出现404解决方法大全
出现404解决方法大全

本专题整合了404错误解决方法大全,阅读专题下面的文章了解更多详细内容。

430

2025.12.31

html5怎么播放视频
html5怎么播放视频

想让网页流畅播放视频?本合集详解HTML5视频播放核心方法!涵盖<video>标签基础用法、多格式兼容(MP4/WebM/OGV)、自定义播放控件、响应式适配及常见浏览器兼容问题解决方案。无需插件,纯前端实现高清视频嵌入,助你快速打造现代化网页视频体验。

15

2025.12.31

关闭win10系统自动更新教程大全
关闭win10系统自动更新教程大全

本专题整合了关闭win10系统自动更新教程大全,阅读专题下面的文章了解更多详细内容。

11

2025.12.31

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
React 教程
React 教程

共58课时 | 3.2万人学习

Pandas 教程
Pandas 教程

共15课时 | 0.9万人学习

ASP 教程
ASP 教程

共34课时 | 3.1万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号