
提升PyTorch DataLoader效率:避免重复实例化
在PyTorch深度学习训练中,高效的数据加载至关重要。 反复创建DataLoader实例会导致进程池的重复创建和销毁,严重影响训练速度。本文介绍如何复用DataLoader,避免这种低效的重复实例化操作。
问题:许多代码在每次迭代中都重新创建DataLoader:DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)。 这会造成性能瓶颈,因为DataLoader初始化需要创建进程池,频繁地创建和销毁进程池会消耗大量资源。
解决方案:将DataLoader的创建移至训练循环之外。 只需在训练开始前创建一次DataLoader实例,并在训练循环中重复使用它即可。 以下代码演示了改进后的方法:
import torch
from torch.utils.data import DataLoader, Dataset
from math import sqrt
from typing import List, Tuple, Union
from numpy import ndarray
from PIL import Image
from torchvision import transforms
preprocess = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
class PreprocessImageDataset(Dataset):
def __init__(self, images: Union[List[ndarray], Tuple[ndarray]]):
self.images = images
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
image = self.images[idx]
image = Image.fromarray(image)
preprocessed_image: torch.Tensor = preprocess(image)
unsqueezed_image = preprocessed_image
return unsqueezed_image
if __name__=='__main__':
data = list(range(10000000))
batch_size = 10
num_workers = 16
dataset = PreprocessImageDataset(data)
dataloader = DataLoader(dataset, batch_size=batch_size,
num_workers=num_workers)
for epoch in range(5):
print(f"Epoch {epoch + 1}:")
for batch_data in dataloader:
batch_data
print("Batch data:", batch_data)
print("Batch data type :", type(batch_data))
print("Batch data shape:", batch_data.shape)通过将DataLoader的实例化放在循环外,并在多个epoch中复用同一个实例,我们避免了重复创建进程池,显著提高了数据加载效率,减少了系统开销,从而提升了训练性能。
以上就是PyTorch DataLoader 如何避免重复实例化以提升训练效率?的详细内容,更多请关注php中文网其它相关文章!
每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。
Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号