Lightweight Vision Transformer: Reproducing EdgeViTs

P粉084495128
Published: 2025-07-29 09:32:11
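This article reproduces EdgeViTs, a family of lightweight Vision Transformer models. To suit mobile devices, EdgeViTs adopt a hierarchical pyramid structure and introduce a Local-Global-Local (LGL) bottleneck, which combines local aggregation, global sparse attention, and local propagation to cut computation while preserving both global and local context. The article gives Paddle implementations of each component and of the overall model, then trains and validates on the Flowers dataset.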

Abstract

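  In computer vision, self-attention-based models such as Vision Transformers (ViTs) have become a highly competitive alternative to CNNs. Although ever stronger variants achieve ever higher recognition accuracy, existing ViTs are demanding in both computation and model size because self-attention has quadratic complexity in the number of tokens. Successful CNN design choices, such as convolutions and hierarchical structure, have been carried over into recent ViTs, but these are still not enough to meet the limited compute budgets of mobile devices. This has motivated recent attempts to build light ViTs, such as MobileViT, on top of the state-of-the-art MobileNet-v2, yet MobileViT still trails MobileNet-v2 in performance. This work pushes further in that direction and introduces EdgeViTs, a new family of lightweight ViTs that, for the first time, lets self-attention-based vision models match the best lightweight CNNs in the trade-off between accuracy and on-device efficiency.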

1 EdgeViTs

1.1 Overall architecture

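  To design a lightweight ViT suitable for mobile and edge devices, the authors adopt the hierarchical pyramid structure used in recent ViT variants (Figure 2(a)). A pyramid Transformer typically reduces spatial resolution and widens the channel dimension from stage to stage. Each stage stacks several Transformer blocks that operate on tensors of the same shape, much like the layered design of ResNet.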
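The pyramid idea can be made concrete with a little shape arithmetic. The sketch below is only an illustration: the function name `pyramid_shapes`, the stride sequence (4, 2, 2, 2), and the channel widths are assumptions chosen for the example, not values taken from the paper. Only the first stage (48 channels at 56×56 for a 224×224 input) corresponds to the code in section 2.

```python
def pyramid_shapes(image_size, strides, channels):
    """Return (channels, height, width) of the feature map after each stage."""
    shapes = []
    size = image_size
    for stride, ch in zip(strides, channels):
        size //= stride  # each stage begins by downsampling the spatial grid
        shapes.append((ch, size, size))
    return shapes


# Strides and channel widths are illustrative assumptions.
stages = pyramid_shapes(224, (4, 2, 2, 2), (48, 96, 192, 384))
for index, (ch, h, w) in enumerate(stages, start=1):
    print(f"stage {index}: C={ch}, H={h}, W={w}")
```

As resolution falls and channel width grows, the number of tokens per stage shrinks by the square of the stride, which is what keeps the later, wider stages affordable.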

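  Within the Transformer block, the authors introduce a cost-effective bottleneck, Local-Global-Local (LGL) (Figure 2(b)). LGL uses a sparse attention module to further reduce the cost of self-attention (Figure 2(c)), which gives a better accuracy-latency trade-off.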

1.2 Local-Global-Local bottleneck (LGL)

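  Earlier Transformer blocks run self-attention at every spatial position. The LGL bottleneck instead computes self-attention over only a subset of the input tokens, yet still supports full spatial interaction, as in standard Multi-Head Self-Attention (MHSA). This shrinks the set of tokens that attention operates on while preserving the underlying information flow needed to model both global and local context.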

  To achieve this, the authors decompose self-attention into consecutive modules that process spatial tokens at different ranges (Figure 2(b)).

  Three efficient operations are introduced:

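  • Local aggregation: aggregates signals only from spatially nearby tokens;
  • Global sparse attention: models long-range relations among a set of representative (delegate) tokens, each of which stands for one local window;
  • Local propagation: diffuses the global context learned by the delegates to the non-representative tokens in the same window.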

  • Local aggregation

  For each token, depth-wise and point-wise convolutions aggregate information within a local k×k window (Figure 3(a)).
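A quick parameter count shows why the depth-wise plus point-wise split is cheap. The helper `conv2d_params` below is a hypothetical name introduced for this sketch. It assumes a bias term on every convolution and uses dim = 48 and k = 3, the values used in the code in section 2.

```python
def conv2d_params(c_in, c_out, k, groups=1, bias=True):
    """Parameter count of a 2D convolution with a square k x k kernel."""
    weights = (c_in // groups) * c_out * k * k
    return weights + (c_out if bias else 0)


dim = 48
depthwise = conv2d_params(dim, dim, 3, groups=dim)  # one 3x3 filter per channel
pointwise = conv2d_params(dim, dim, 1)              # 1x1 channel mixing
standard = conv2d_params(dim, dim, 3)               # dense 3x3 convolution

print(depthwise, pointwise, standard)  # 480 2352 20784
```

The depth-wise and point-wise figures (480 and 2,352) agree with the Conv2D rows of the model summary in section 2, and together they are about seven times smaller than a single dense 3×3 convolution of the same width.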

  • Global sparse attention

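  A sparse set of representative tokens is sampled uniformly over the spatial grid, one per r×r window, where r is the subsampling rate. Self-attention is then applied only to these selected tokens (Figure 3(b)). According to the authors, this differs from all existing ViTs, in which every spatial token serves as a query in the self-attention computation.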
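The saving from subsampling can be estimated by counting query-key pairs. This is a rough sketch: `attention_pairs` is a hypothetical helper, and it counts only the size of the attention matrix, ignoring the projections and the number of heads. The 56×56 grid and r = 4 match the code in section 2.

```python
def attention_pairs(height, width, r=1):
    """Number of tokens and query-key pairs when one token per r x r window is kept."""
    tokens = (height // r) * (width // r)
    return tokens, tokens * tokens


full_tokens, full_pairs = attention_pairs(56, 56)           # every position attends
sparse_tokens, sparse_pairs = attention_pairs(56, 56, r=4)  # one delegate per 4x4 window

print(full_tokens, full_pairs)      # 3136 9834496
print(sparse_tokens, sparse_pairs)  # 196 38416
print(full_pairs // sparse_pairs)   # 256
```

Because attention cost is quadratic in the token count, keeping one token per r×r window shrinks the attention matrix by a factor of r⁴, which is 256 for r = 4.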

  • Local propagation

  A transposed convolution propagates the global context encoded in the representative tokens to their neighboring tokens (Figure 3(c)).
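When the kernel size equals the stride, a depth-wise transposed convolution has no overlapping windows, so each output position depends on exactly one delegate and one kernel weight. The pure-Python sketch below illustrates that special case for a single channel without bias. The function name `local_propagate` and the example values are assumptions made for illustration, not part of the original code.

```python
def local_propagate(reps, kernel):
    """Single-channel transposed convolution with kernel_size == stride == r, no bias.

    Each delegate value reps[i][j] is written into its own r x r output window,
    scaled by the kernel weight for that position inside the window.
    """
    r = len(kernel)
    h, w = len(reps), len(reps[0])
    out = [[0.0] * (w * r) for _ in range(h * r)]
    for i in range(h * r):
        for j in range(w * r):
            out[i][j] = reps[i // r][j // r] * kernel[i % r][j % r]
    return out


reps = [[1.0, 2.0], [3.0, 4.0]]       # 2x2 grid of delegate tokens (illustrative values)
kernel = [[1.0, 0.5], [0.5, 0.25]]    # per-position weights for r = 2 (illustrative values)
for row in local_propagate(reps, kernel):
    print(row)
```

In the Paddle code in section 2 the same role is played by a `Conv2DTranspose` with `kernel_size=sample_rate`, `stride=sample_rate`, and `groups=dim`, where the kernel weights are learned rather than fixed.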

2 Code reproduction

In [1]
import paddle
import paddle.nn as nn
from paddle.nn import Conv2D as Conv2d
from paddle.nn import BatchNorm2D as BatchNorm2d
from paddle.nn import Linear
from paddle.nn import AvgPool2D as AvgPool2d
from paddle.nn import Conv2DTranspose as ConvTranspose2d
from paddle.nn import LayerNorm, GELU
   
In [2]
class Residual(nn.Layer):
    """Wrap a module with a skip connection: y = x + module(x)."""
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x):
        return x + self.module(x)


class LocalAgg(nn.Layer):
    """Local aggregation: point-wise -> depth-wise 3x3 -> point-wise convolutions."""
    def __init__(self, dim):
        super().__init__()
        self.conv1 = Conv2d(dim, dim, 1)
        self.conv2 = Conv2d(dim, dim, 3, padding=1, groups=dim)
        self.conv3 = Conv2d(dim, dim, 1)
        self.norm1 = BatchNorm2d(dim)
        self.norm2 = BatchNorm2d(dim)

    def forward(self, x):
        """
        [B, C, H, W] = x.shape
        """
        x = self.conv1(self.norm1(x))
        x = self.conv2(x)
        x = self.conv3(self.norm2(x))
        return x

class GlobalSparseAttn(nn.Layer):
    """Global sparse attention over one token per window, followed by local propagation."""
    def __init__(self, dim, sample_rate=4, scale=1, num_heads=1):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = scale  # the default of 1 leaves the attention logits unscaled
        self.qkv = Linear(dim, dim * 3)
        # A 1x1 average pool with stride r keeps one token per r x r window.
        self.sampler = AvgPool2d(1, stride=sample_rate)
        # Depth-wise transposed convolution spreads each delegate over its window.
        self.LocalProp = ConvTranspose2d(dim, dim, kernel_size=sample_rate,
                                         stride=sample_rate, groups=dim)
        self.proj = Linear(dim, dim)

    def forward(self, x):
        """
        [B, C, H0, W0] = x.shape; H0 and W0 are assumed divisible by sample_rate
        """
        B, C, H0, W0 = x.shape
        x = self.sampler(x)
        H, W = x.shape[2], x.shape[3]
        x = x.flatten(2)
        x = x.transpose([0, 2, 1])  # [B, H*W, C]

        x = self.qkv(x)  # [B, H*W, 3*C]
        x = x.transpose([0, 2, 1])  # [B, 3*C, H*W]
        q, k, v = x.reshape([B, self.num_heads, -1, H * W]).split(
            [self.head_dim, self.head_dim, self.head_dim], axis=2)

        attn = (q.transpose([0, 1, 3, 2]) @ k) * self.scale  # [B, heads, H*W, H*W]
        attn = nn.functional.softmax(attn, axis=-1)
        x = v @ attn.transpose([0, 1, 3, 2])  # [B, heads, head_dim, H*W]

        x = x.reshape([B, C, H, W])
        x = self.LocalProp(x)  # back to [B, C, H0, W0]

        x = paddle.nn.functional.layer_norm(x, x.shape[1:])
        x = x.flatten(2)
        x = x.transpose([0, 2, 1])
        x = self.proj(x)
        x = x.transpose([0, 2, 1])
        x = x.reshape([B, C, H0, W0])
        return x

class DownSampleLayer(nn.Layer):
    """Strided convolution that downsamples and embeds the input."""
    def __init__(self, dim_in=3, dim_out=48, downsample_rate=4):
        super().__init__()
        self.downsample = Conv2d(dim_in, dim_out, kernel_size=downsample_rate,
                                 stride=downsample_rate)

    def forward(self, x):
        x = self.downsample(x)
        # Parameter-free layer norm over (C, H, W) of each sample.
        x = paddle.nn.functional.layer_norm(x, x.shape[1:])
        return x

class PatchEmbed(nn.Layer):  
    def __init__(self, dim):
        super().__init__()
        self.embed = Conv2d(dim, dim, 3, padding=1, groups=dim)  
    def forward(self, x):  
        return x + self.embed(x)  

class FFN(nn.Layer):
    """Token-wise MLP with 4x expansion, applied to a [B, C, H, W] feature map."""
    def __init__(self, dim=48):
        super().__init__()
        self.fc1 = nn.Linear(dim, dim * 4)
        self.fc2 = nn.Linear(dim * 4, dim)

    def forward(self, x):
        B, C, H, W = x.shape
        x = x.flatten(2)
        x = x.transpose([0, 2, 1])  # [B, H*W, C]

        x = self.fc1(x)
        x = nn.functional.gelu(x)
        x = self.fc2(x)

        x = x.transpose([0, 2, 1])
        x = x.reshape([B, C, H, W])
        return x
   
In [ ]
class EdgeViT(nn.Layer):
    """Single-stage EdgeViT-style model: downsample, local block, global sparse attention block."""
    def __init__(self, dim_in=3, dim_out=48, downsample_rate=4, dim=48):
        super().__init__()

        self.downsample1 = DownSampleLayer(dim_in=dim_in, dim_out=dim_out,
                                           downsample_rate=downsample_rate)
        self.patchembeding1 = PatchEmbed(dim=dim)
        self.residual_add1 = Residual(LocalAgg(dim=dim))
        self.residual_add1_1 = Residual(FFN(dim=dim))

        self.patchembeding2 = PatchEmbed(dim=dim)
        self.residual_add2 = Residual(GlobalSparseAttn(dim=dim))
        # Classifier on the flattened 56x56 feature map of a 224x224 input (48*56*56 = 150528).
        self.fc = nn.Linear(dim * 56 * 56, 103)

    def forward(self, x):
        x = self.downsample1(x)
        x = self.patchembeding1(x)
        x = self.residual_add1(x)
        x = self.residual_add1_1(x)
        x = self.patchembeding2(x)
        x = self.residual_add2(x)
        x = x.flatten(1)  # [B, C*H*W]
        # print(x.shape)
        x = self.fc(x)
        return x
   
In [4]
cnn = EdgeViT()



paddle.summary(cnn,(1,3,224,224))
       
[1, 150528]
------------------------------------------------------------------------------
   Layer (type)        Input Shape          Output Shape         Param #    
==============================================================================
     Conv2D-1       [[1, 3, 224, 224]]    [1, 48, 56, 56]         2,352     
DownSampleLayer-1   [[1, 3, 224, 224]]    [1, 48, 56, 56]           0       
     Conv2D-2       [[1, 48, 56, 56]]     [1, 48, 56, 56]          480      
   PatchEmbed-1     [[1, 48, 56, 56]]     [1, 48, 56, 56]           0       
  BatchNorm2D-1     [[1, 48, 56, 56]]     [1, 48, 56, 56]          192      
     Conv2D-3       [[1, 48, 56, 56]]     [1, 48, 56, 56]         2,352     
     Conv2D-4       [[1, 48, 56, 56]]     [1, 48, 56, 56]          480      
  BatchNorm2D-2     [[1, 48, 56, 56]]     [1, 48, 56, 56]          192      
     Conv2D-5       [[1, 48, 56, 56]]     [1, 48, 56, 56]         2,352     
    LocalAgg-1      [[1, 48, 56, 56]]     [1, 48, 56, 56]           0       
    Residual-1      [[1, 48, 56, 56]]     [1, 48, 56, 56]           0       
     Linear-1        [[1, 3136, 48]]       [1, 3136, 192]         9,408     
     Linear-2        [[1, 3136, 192]]      [1, 3136, 48]          9,264     
      FFN-1         [[1, 48, 56, 56]]     [1, 48, 56, 56]           0       
    Residual-2      [[1, 48, 56, 56]]     [1, 48, 56, 56]           0       
     Conv2D-6       [[1, 48, 56, 56]]     [1, 48, 56, 56]          480      
   PatchEmbed-2     [[1, 48, 56, 56]]     [1, 48, 56, 56]           0       
   AvgPool2D-1      [[1, 48, 56, 56]]     [1, 48, 14, 14]           0       
     Linear-3         [[1, 196, 48]]       [1, 196, 144]          7,056     
Conv2DTranspose-1   [[1, 48, 14, 14]]     [1, 48, 56, 56]          816      
     Linear-4        [[1, 3136, 48]]       [1, 3136, 48]          2,352     
GlobalSparseAttn-1  [[1, 48, 56, 56]]     [1, 48, 56, 56]           0       
    Residual-3      [[1, 48, 56, 56]]     [1, 48, 56, 56]           0       
     Linear-5         [[1, 150528]]           [1, 103]         15,504,487   
==============================================================================
Total params: 15,542,263
Trainable params: 15,541,879
Non-trainable params: 384
------------------------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 27.85
Params size (MB): 59.29
Estimated Total Size (MB): 87.71
------------------------------------------------------------------------------
       
{'total_params': 15542263, 'trainable_params': 15541879}
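The summary shows that almost all parameters sit in the final `Linear` layer, which is fed the flattened 48×56×56 feature map. The arithmetic below reproduces that row. The pooled head at the end is a hypothetical comparison (global average pooling before the classifier) and is not part of the article's model; `linear_params` is a helper name introduced for this sketch.

```python
def linear_params(n_in, n_out):
    """Parameter count of a fully connected layer with bias."""
    return n_in * n_out + n_out


total_params = 15_542_263                        # from the summary above
flatten_head = linear_params(48 * 56 * 56, 103)  # classifier on the flattened map
share = flatten_head / total_params

# Hypothetical alternative: average over H and W first, then classify 48 features.
pooled_head = linear_params(48, 103)

print(flatten_head)      # 15504487
print(f"{share:.1%}")
print(pooled_head)       # 5047
```

Everything outside the classifier amounts to 37,776 parameters, so the model size reported here says more about the flattened head than about the LGL components themselves.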
               

3 Model training

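  The paper's experiments use ImageNet, but the platform used here cannot pull that dataset, so the Flowers dataset is used for validation instead. The goal is only to verify that the model trains end to end. No comparison experiments are set up, since results on such a small dataset would not be comparable.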

In [5]
import paddle
from paddle.vision.datasets import Flowers
from paddle.vision.transforms import Compose, Normalize, Resize, ToTensor

# Resize the image first, convert it to a CHW float tensor, then normalize each channel.
normalize = Normalize(mean=[0.5, 0.5, 0.5],
                      std=[0.5, 0.5, 0.5],
                      data_format='CHW')
transform = Compose([Resize(size=(224, 224)), ToTensor(), normalize])

flowers_train = Flowers(mode='train', transform=transform)
flowers_test = Flowers(mode='test', transform=transform)

# Build the training data loader
train_loader = paddle.io.DataLoader(flowers_train, batch_size=1, shuffle=True)
# Build the test data loader
test_loader = paddle.io.DataLoader(flowers_test, batch_size=1, shuffle=False)

print('=============train dataset=============')
for image, label in flowers_train:
    print('image shape: {}, label: {}'.format(image.shape, label))
    break
       
=============train dataset=============
image shape: [3, 224, 224], label: [1]
       
In [ ]
from paddle.metric import Accuracy


model = paddle.Model(EdgeViT())
optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())

model.prepare(
    optim,
    paddle.nn.CrossEntropyLoss(),
    Accuracy()
    )

model.fit(train_data=train_loader,
        eval_data=test_loader,
        epochs=2,
        verbose=1
        )
   
