总结
豆包 AI 助手文章总结

python类参数定义及数据扩展方式unsqueeze/expand

WBOY
发布: 2022-08-24 13:32:40
转载
2678人浏览过

【相关推荐:python3视频教程

类的参数定义

将conda环境设置为ai,conda activate ai

这个文件的由来:

由于在yolov1的pytorch实现的损失函数中,看到继承了nn.Module,并且其中两个参数不像c++那里指定类型,那么他们的类型是哪里来的

这里就是在探索这样一件事

立即学习Python免费学习笔记(深入)”;

操作逻辑:

  • 先在类中定义了构造函数以及一个自定义函数;
  • 构造函数定义了属性S、B,自定义函数引入两个参数,对两个参数进行调用
    • 这里就说明参数的结构是怎么样的,取决于参数被调用了什么东西,比如这里调用了N = box1.size(0) M = box2.size(0)说明了它是类似一个矩阵的东西,对应的box1的定义就是`torch.rand(10,4)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

#探究属性S,B是如何产生的,以及box1、box2是如何产生的、如何调用
class yoloLoss(nn.Module):
    def __init__(self,S,B):
        self.S=S
        self.B=B
    def compute_iot(self,box1,box2):
        N = box1.size(0)  #调用方式就表示了变量是什么类型,这里是一个张量,其中每个元素是一个tensor,所以是N*4的张量
        M = box2.size(0)
        print(M,N)

yoloLoss1 =yoloLoss(10, 11)
yoloLoss1.compute_iot(torch.rand(10,4),torch.rand(11,4))
登录后复制

数据扩展

探究unsqueeze以及expand的使用方法,unsqueeze可以增加一个纬度,但是维度的siz只是1而已,而expand就可以将数据进行复制,将数据变为n

# 获得一开始的初始化数值:tensor([[a1,a2,a3]])
nn1=torch.rand(1,3)
print(nn1)
# unsqueeze是解压的意思,在第i个维度上进行扩展,将其扩展为tensor([[[a1,a2,a3]]])
nn1=nn1.unsqueeze(0)
print("*"*100)
print(nn1)
#利用expand对数据进行扩展
nn1=nn1.expand(1,3,3)
print("*"*100)
print(nn1)
登录后复制

【相关推荐:python3视频教程

以上就是python类参数定义及数据扩展方式unsqueeze/expand的详细内容,更多请关注php中文网其它相关文章!

python速学教程(入门到精通)
python速学教程(入门到精通)

python怎么学习?python怎么入门?python在哪学?python怎么学才快?不用担心,这里为大家提供了python速学教程(入门到精通),有需要的小伙伴保存下载就能学习啦!

下载
相关标签:
来源:脚本之家网
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn
最新问题
豆包 AI 助手文章总结
开源免费商场系统广告
热门教程
更多>
最新下载
更多>
网站特效
网站源码
网站素材
前端模板
关于我们 免责申明 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送
PHP中文网APP
随时随地碎片化学习
PHP中文网抖音号
发现有趣的

Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号