
利用ONNX Runtime高效运行PyTorch模型
本文将指导您如何使用ONNX Runtime运行经torch.onnx.export导出的PyTorch模型,并重点解决PyTorch张量与ONNX Runtime所需NumPy数组类型不兼容的问题。
首先,我们来看一个PyTorch模型导出示例:
import torch
class SumModule(torch.nn.Module):
    def forward(self, x):
        return torch.sum(x, dim=1)
torch.onnx.export(
    SumModule(),
    (torch.ones(2, 2),),
    "onnx.pb",
    input_names=["x"],
    output_names=["sum"]
)这段代码定义了一个简单的PyTorch模型SumModule,并将其导出为名为onnx.pb的ONNX模型文件。
直接使用PyTorch张量作为ONNX Runtime的输入会导致错误,因为ONNX Runtime期望的是NumPy数组。 错误信息通常提示输入类型错误。
为了解决这个问题,我们需要将PyTorch张量转换为NumPy数组。 正确的代码如下:
import onnxruntime
import numpy as np
import torch
ort_session = onnxruntime.InferenceSession("onnx.pb")
# 关键修改:将torch.Tensor转换为np.ndarray
x = np.ones((2, 2), dtype=np.float32)
inputs = {ort_session.get_inputs()[0].name: x}
print(ort_session.run(None, inputs))这段代码加载onnx.pb文件,创建一个形状为(2, 2),数据类型为float32的NumPy数组作为模型输入。 ort_session.get_inputs()[0].name 获取输入张量的名称,确保输入数据与模型定义匹配。  ort_session.run 函数运行模型并打印输出结果。
更简洁的等效代码:
import onnxruntime as ort
import numpy as np
sess = ort.InferenceSession("onnx.pb")
input_data = np.ones((2, 2)).astype(np.float32)
output_data = sess.run(None, {"x": input_data})[0]
print(output_data)这段代码功能相同,但更简洁易读。 关键在于使用NumPy数组作为输入。
通过以上方法,您可以成功加载并运行使用torch.onnx.export导出的PyTorch模型。  请确保输入数据的类型和形状与模型的预期输入相匹配。
以上就是如何用ONNX Runtime运行PyTorch导出的模型并解决类型不兼容问题?的详细内容,更多请关注php中文网其它相关文章!
 
                        
                        每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。
 
                Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号