0

0

PyTorch ResNet50模型导出ONNX时如何解决动态batch_size难题?

DDD

DDD

发布时间:2025-03-20 09:40:02

|

330人浏览过

|

来源于php中文网

原创

解决pytorch resnet50模型导出onnx时动态batch_size难题

本文介绍如何将基于ResNet50的PyTorch模型导出为ONNX格式,重点解决动态batch_size导致的导出问题。原始代码中,imageretrievalnet类和gem类存在一些与ONNX导出不兼容的因素,主要包括gem类中可学习参数self.p以及imageretrievalnet类中未使用的self.lwhiten属性。这些动态元素阻碍了ONNX的shape推断,导致导出失败。

PyTorch ResNet50模型导出ONNX时如何解决动态batch_size难题?

为了解决这个问题,我们需要修改这两个类以适应ONNX导出流程。具体修改如下:

首先,修改gem类,将self.p参数改为直接赋值的常量,不再作为可学习参数:

Quillbot
Quillbot

一款AI写作润色工具,QuillBot的人工智能改写工具将提高你的写作能力。

下载
class gem(nn.Module):
    def __init__(self, p=3, eps=1e-6):
        super(gem, self).__init__()
        self.p = p  # 直接赋值常量值
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return gem_op(x, p=self.p, eps=self.eps) # 使用自定义的gem_op函数,避免直接使用类名调用

然后,简化imageretrievalnet类,去除未使用的self.lwhiten属性:

class imageretrievalnet(nn.Module):
    def __init__(self, dim: int = 512):
        super(imageretrievalnet, self).__init__()
        resnet50_model = models.resnet50()
        features = list(resnet50_model.children())[:-2]
        self.features = nn.Sequential(*features)
        self.pool = gem()
        self.whiten = nn.Linear(2048, dim, bias=True) # 使用nn.Linear
        self.norm = l2n()

    def forward(self, x: torch.Tensor):
        o: torch.Tensor = self.features(x)
        pooled_t = self.pool(o)
        normed_t: torch.Tensor = self.norm(pooled_t)
        o: torch.Tensor = normed_t.squeeze(-1).squeeze(-1)

        if self.whiten is not None:
            whitened_t = self.whiten(o)
            normed_t: torch.Tensor = self.norm(whitened_t)
            o = normed_t

        return o.permute(1, 0)

通过以上修改,消除了动态参数带来的不确定性,使ONNX导出能够顺利进行。 使用修改后的imageretrievalnet类,并利用torch.onnx.export函数,指定dynamic_axes参数处理动态batch_size,即可成功导出ONNX模型:

model = imageretrievalnet()
batch_size = 4
input_shape = (batch_size, 3, 224, 224)
input_data = torch.randn(input_shape)
torch.onnx.export(
    model,
    input_data,
    "resnet50.onnx",
    input_names=["input"], output_names=["output"],
    opset_version=12,
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}
)

记住根据实际情况调整opset_version参数。 通过这些修改,即可成功导出支持动态batch_size的ResNet50 ONNX模型。 请注意,代码中添加了gem_op函数的假设,该函数应该实现gem类的功能,以避免在ONNX导出过程中直接使用类名调用。

相关标签:

本站声明:本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn

相关专题

更多
java基础知识汇总
java基础知识汇总

java基础知识有Java的历史和特点、Java的开发环境、Java的基本数据类型、变量和常量、运算符和表达式、控制语句、数组和字符串等等知识点。想要知道更多关于java基础知识的朋友,请阅读本专题下面的的有关文章,欢迎大家来php中文网学习。

1463

2023.10.24

pytorch是干嘛的
pytorch是干嘛的

pytorch是一个基于python的深度学习框架,提供以下主要功能:动态图计算,提供灵活性。强大的张量操作,实现高效处理。自动微分,简化梯度计算。预构建的神经网络模块,简化模型构建。各种优化器,用于性能优化。想了解更多pytorch的相关内容,可以阅读本专题下面的文章。

429

2024.05.29

Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习
Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习

PyTorch 是一种用于构建深度学习模型的功能完备框架,是一种通常用于图像识别和语言处理等应用程序的机器学习。 使用Python 编写,因此对于大多数机器学习开发者而言,学习和使用起来相对简单。 PyTorch 的独特之处在于,它完全支持GPU,并且使用反向模式自动微分技术,因此可以动态修改计算图形。

19

2025.12.22

c++主流开发框架汇总
c++主流开发框架汇总

本专题整合了c++开发框架推荐,阅读专题下面的文章了解更多详细内容。

97

2026.01.09

c++框架学习教程汇总
c++框架学习教程汇总

本专题整合了c++框架学习教程汇总,阅读专题下面的文章了解更多详细内容。

51

2026.01.09

学python好用的网站推荐
学python好用的网站推荐

本专题整合了python学习教程汇总,阅读专题下面的文章了解更多详细内容。

139

2026.01.09

学python网站汇总
学python网站汇总

本专题整合了学python网站汇总,阅读专题下面的文章了解更多详细内容。

12

2026.01.09

python学习网站
python学习网站

本专题整合了python学习相关推荐汇总,阅读专题下面的文章了解更多详细内容。

19

2026.01.09

俄罗斯手机浏览器地址汇总
俄罗斯手机浏览器地址汇总

汇总俄罗斯Yandex手机浏览器官方网址入口,涵盖国际版与俄语版,适配移动端访问,一键直达搜索、地图、新闻等核心服务。

83

2026.01.09

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号