
在深度学习模型训练中,torch.utils.data.dataloader是pytorch提供的一个强大工具,用于高效地加载数据。它通常与dataset结合使用,负责数据的批处理、打乱和多进程加载等任务。最常见的用法是指定一个固定的batch_size参数:
import torch
from torch.utils.data import TensorDataset, DataLoader
# 示例数据
x_train = torch.randn(8400, 4)
y_train = torch.randint(0, 2, (8400,))
train_dataset = TensorDataset(x_train, y_train)
# 使用固定批次大小的DataLoader
dataloader_train = DataLoader(train_dataset, batch_size=64, shuffle=True)
# 迭代DataLoader
for batch_idx, (data, target) in enumerate(dataloader_train):
print(f"Batch {batch_idx}: data shape {data.shape}, target shape {target.shape}")
if batch_idx == 2: # 仅打印前3个批次
break这种方法简单直接,适用于大多数场景。然而,在某些特定的训练策略中,我们可能需要根据训练阶段、模型状态或数据特性来动态调整批次大小,例如:
PyTorch的DataLoader支持通过sampler或batch_sampler参数来完全控制批次中样本的索引选择。这是实现动态批次大小的关键。
对于动态批次大小的需求,由于我们希望直接指定每个批次的大小(即每个批次包含多少个样本),因此自定义一个生成批次索引的采样器(更接近BatchSampler的功能)是最佳选择。
我们将创建一个名为VariableBatchSampler的类,它继承自torch.utils.data.Sampler,但其行为更像一个BatchSampler,直接返回批次的索引列表。
import torch
from torch.utils.data import TensorDataset, DataLoader, Sampler
class VariableBatchSampler(Sampler):
"""
一个自定义采样器,根据预定义的批次大小列表生成批次索引。
它将数据集按顺序切片,形成指定大小的批次。
"""
def __init__(self, dataset_len: int, batch_sizes: list):
"""
初始化VariableBatchSampler。
Args:
dataset_len (int): 数据集的总长度。
batch_sizes (list): 一个整数列表,每个元素代表一个批次的样本数量。
所有批次大小之和应等于或大于dataset_len。
"""
if not isinstance(dataset_len, int) or dataset_len <= 0:
raise ValueError("dataset_len 必须是正整数。")
if not isinstance(batch_sizes, list) or not all(isinstance(bs, int) and bs > 0 for bs in batch_sizes):
raise ValueError("batch_sizes 必须是包含正整数的列表。")
if sum(batch_sizes) < dataset_len:
print(f"警告:批次大小总和 ({sum(batch_sizes)}) 小于数据集长度 ({dataset_len}),部分数据可能不会被加载。")
self.dataset_len = dataset_len
self.batch_sizes = batch_sizes
self.current_batch_idx = 0 # 当前批次在batch_sizes列表中的索引
self.current_start_idx = 0 # 当前批次在数据集中的起始索引
def __iter__(self):
"""
使采样器成为一个迭代器。每次新的迭代开始时,重置状态。
"""
self.current_batch_idx = 0
self.current_start_idx = 0
return self
def __next__(self):
"""
生成下一个批次的索引。
"""
# 如果已经遍历完所有批次或超出了数据集长度,则停止迭代
if self.current_start_idx >= self.dataset_len or \
self.current_batch_idx >= len(self.batch_sizes):
raise StopIteration()
# 获取当前批次的大小
current_batch_size = self.batch_sizes[self.current_batch_idx]
# 计算当前批次的结束索引
current_end_idx = min(self.current_start_idx + current_batch_size, self.dataset_len)
# 生成批次索引
batch_indices = torch.arange(self.current_start_idx, current_end_idx, dtype=torch.long)
# 更新状态,为下一个批次做准备
self.current_start_idx = current_end_idx
self.current_batch_idx += 1
return batch_indices.tolist() # DataLoader期望的是Python列表代码解析:
现在,我们将这个自定义采样器与DataLoader结合使用。
# 示例数据
x_train = torch.randn(8400, 4)
y_train = torch.randint(0, 2, (8400,))
train_dataset = TensorDataset(x_train, y_train)
# 定义动态批次大小列表
# 注意:这些批次大小的总和不一定需要精确等于数据集长度,
# 我们的采样器会处理最后可能不足一个完整批次的情况。
list_batch_size = [30, 60, 110, 200, 50, 150, 90, 120, 70, 180] * 20 # 假设有20个这样的循环
# 确保批次大小总和足够覆盖数据集,或者让DataLoader处理剩余部分
if sum(list_batch_size) < len(train_dataset):
print("警告:提供的批次大小总和小于数据集长度,部分数据可能不会被加载。")
# 可以选择在末尾添加一个批次以覆盖剩余数据
# list_batch_size.append(len(train_dataset) - sum(list_batch_size))
# 实例化自定义采样器
variable_sampler = VariableBatchSampler(dataset_len=len(train_dataset), batch_sizes=list_batch_size)
# 将采样器传递给DataLoader
# 推荐使用 batch_sampler 参数
data_loader_dynamic = DataLoader(train_dataset, batch_sampler=variable_sampler, num_workers=0) # num_workers=0 for simplicity
print(f"\n使用动态批次大小的DataLoader (通过 batch_sampler):")
for batch_idx, (data, target) in enumerate(data_loader_dynamic):
print(f"Batch {batch_idx}: data shape {data.shape}, target shape {target.shape}")
if batch_idx >= 15: # 仅打印前16个批次
break
print(f"总共生成了 {batch_idx + 1} 个批次。")使用 batch_sampler 的优势:
当你的自定义采样器(如VariableBatchSampler)已经直接返回批次的索引列表时,将其作为DataLoader的batch_sampler参数传递是更推荐的做法。
批次大小总和与数据集长度:确保batch_sizes列表中所有元素的总和能够覆盖整个数据集。如果总和小于数据集长度,那么部分数据将不会被模型训练到。如果总和大于数据集长度,VariableBatchSampler会自然地在达到dataset_len时停止。
数据打乱(Shuffling):我们当前的VariableBatchSampler是按顺序生成批次的。如果需要在每个epoch开始时打乱数据,你需要修改采样器:
# 示例:带有打乱功能的VariableBatchSampler (概念性代码)
class ShuffledVariableBatchSampler(Sampler):
def __init__(self, dataset_len: int, batch_sizes: list):
# ... (同上)
self.dataset_len = dataset_len
self.batch_sizes = batch_sizes
self.shuffled_indices = None # 用于存储打乱后的索引
def __iter__(self):
self.current_batch_idx = 0
self.current_start_idx = 0
# 在每个epoch开始时打乱索引
self.shuffled_indices = torch.randperm(self.dataset_len).tolist()
return self
def __next__(self):
if self.current_start_idx >= self.dataset_len or \
self.current_batch_idx >= len(self.batch_sizes):
raise StopIteration()
current_batch_size = self.batch_sizes[self.current_batch_idx]
# 从打乱的索引中获取批次
batch_indices_in_shuffled = self.shuffled_indices[self.current_start_idx : self.current_start_idx + current_batch_size]
self.current_start_idx += len(batch_indices_in_shuffled)
self.current_batch_idx += 1
return batch_indices_in_shuffleddrop_last参数:当使用batch_sampler时,DataLoader的drop_last参数会被忽略,因为批次的构成完全由batch_sampler控制。如果需要丢弃最后一个不完整的批次,你的VariableBatchSampler需要在生成批次索引时自行判断并处理。在我们的实现中,min(..., self.dataset_len)确保了即使最后一个批次不足指定大小,也会包含所有剩余数据。
通过自定义torch.utils.data.Sampler或更具体地使用batch_sampler参数,我们可以灵活地控制PyTorch DataLoader的批次大小,以适应各种复杂的训练策略。VariableBatchSampler提供了一个实现动态、非固定批次大小的有效范例,它通过直接管理批次索引的生成,赋予了用户对数据加载过程的精细控制。在实际应用中,应根据具体需求考虑是否需要结合数据打乱功能。
以上就是PyTorch DataLoader动态批次大小管理指南的详细内容,更多请关注php中文网其它相关文章!
每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。
Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号