[ICLR 2022] RegionViT: Regional-to-Local Attention for Vision Transformers

P粉084495128
Published: 2025-08-01 13:39:18 · Original
RegionViT proposes a regional-to-local vision Transformer: a pyramid-structured architecture that replaces global self-attention with regional-to-local attention. It first generates regional and local tokens from patches of different sizes; regional self-attention extracts global information, which is then passed to the local tokens through local self-attention, combined with a relative position encoding. The model performs strongly across multiple vision tasks, combining efficiency with a global receptive field and locality.




Abstract

In recent years, vision Transformers (ViTs) have shown strong capability in image classification, comparable to convolutional neural networks (CNNs). However, the vanilla ViT simply inherits the same architecture from natural language processing, which is often not optimized for vision applications. Motivated by this, the paper proposes a new vision Transformer architecture that adopts a pyramid structure and introduces a novel regional-to-local attention, rather than global self-attention, into the vision Transformer. More specifically, the model first generates regional tokens and local tokens from an image with patches of different sizes, where each regional token is associated with a set of local tokens based on spatial location. The regional-to-local attention consists of two steps: first, regional self-attention extracts global information among all regional tokens; then local self-attention exchanges information, via self-attention, between a regional token and its associated local tokens. Therefore, even though local self-attention confines its scope to a local region, it still receives global information. Extensive experiments on four vision tasks — image classification, object and keypoint detection, semantic segmentation, and action recognition — show that the approach outperforms or is on par with existing ViT variants, including many concurrent works.

1. RegionViT

Because global self-attention is computationally expensive, many works adopt local self-attention, i.e., full self-attention restricted to a small region. Local self-attention, however, introduces another problem: the receptive field becomes too small. To address this, the paper proposes a new coarse-to-fine Transformer, RegionViT. Regional tokens interact globally, and the global information they carry is passed to the corresponding local tokens through local self-attention. The overall architecture is shown in Figure 2:

[Figure 2: Overall architecture of RegionViT]

The core module is the regional-to-local Transformer encoder. The main idea is that regional tokens interact globally, and the global information they carry is passed to the corresponding local tokens through local self-attention, as formalized below:

$$
\begin{aligned}
y_r^d &= x_r^{d-1} + \mathrm{RSA}\big(\mathrm{LN}(x_r^{d-1})\big), \\
y_{i,j}^d &= \big[\, y_{r_{i,j}}^d \,\big\|\, \{ x_{l_{i,j,m,n}}^{d-1} \}_{m,n \in M} \,\big], \\
z_{i,j}^d &= y_{i,j}^d + \mathrm{LSA}\big(\mathrm{LN}(y_{i,j}^d)\big), \\
x_{i,j}^d &= z_{i,j}^d + \mathrm{FFN}\big(\mathrm{LN}(z_{i,j}^d)\big)
\end{aligned}
$$
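
To make these equations concrete, here is a minimal runnable sketch of one regional-to-local step (added for illustration; toy_self_attention, the single head without learned projections, and all shapes are simplifications rather than the paper's actual layer — the full implementation follows in Section 2.3):

In [ ]
import paddle
import paddle.nn.functional as F

def toy_self_attention(x):
    # single-head attention without learned projections, for illustration only
    attn = F.softmax((x @ x.transpose([0, 2, 1])) * x.shape[-1] ** -0.5, axis=-1)
    return attn @ x

B, R, K, C = 2, 16, 49, 64            # batch, regions, local tokens per region, dim
regional = paddle.randn((B, R, C))    # x_r: regional tokens
local = paddle.randn((B, R, K, C))    # x_l: local tokens grouped by region

# Step 1 (RSA): all regional tokens exchange global information
regional = regional + toy_self_attention(regional)

# Step 2 (LSA): each regional token is prepended to its own window of local tokens
tokens = paddle.concat([regional.unsqueeze(2), local], axis=2)  # B x R x (1+K) x C
tokens = tokens.reshape((B * R, 1 + K, C))
tokens = tokens + toy_self_attention(tokens)
tokens = tokens.reshape((B, R, 1 + K, C))
regional, local = tokens[:, :, 0], tokens[:, :, 1:]
print(regional.shape, local.shape)  # [2, 16, 64] [2, 16, 49, 64]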


Locality is an important cue for understanding visual content, so the paper adopts a relative position encoding. Notably, the position bias is added only among the local tokens; no positional bias is added between the regional token and the local tokens. The formula is:

$$
a_{(x_m, y_m),(x_n, y_n)} = \mathrm{softmax}\big( q_{(x_m, y_m)} k_{(x_n, y_n)}^{T} + b_{(x_m - x_n,\ y_m - y_n)} \big)
$$
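
A toy illustration (added) of how the bias term b is shared: every token pair with the same coordinate offset (x_m - x_n, y_m - y_n) indexes the same table entry. The window size and the random table below are placeholders; in the real model the table is a learned parameter (see AttentionWithRelPos in Section 2.3.3):

In [ ]
import paddle

ws = 3                                         # toy window side; offsets lie in [-(ws-1), ws-1]
rel_dim = 2 * ws - 1
bias_table = paddle.randn((rel_dim, rel_dim))  # stands in for the learned bias b

def rel_bias(xm, ym, xn, yn):
    # shift the offset into non-negative table indices
    return bias_table[xm - xn + ws - 1, ym - yn + ws - 1]

# pairs with the same offset share one bias entry
print(float(rel_bias(0, 0, 1, 1)), float(rel_bias(1, 1, 2, 2)))  # equal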


2. Code Reproduction

2.1 Download and Import the Required Libraries

In [ ]
%matplotlib inline
import os
import copy  # needed later by convert_to_flatten_layout
import itertools
from functools import partial

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure

import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle.io import Dataset, DataLoader
from paddle.vision.datasets import Cifar10
from paddle.vision.transforms import Transpose
import paddle.vision.transforms as transforms

2.2 Create the Dataset

In [3]
train_tfm = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.6, 1.0)),
    transforms.ColorJitter(brightness=0.2,contrast=0.2, saturation=0.2),
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomRotation(20),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

test_tfm = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
In [4]
paddle.vision.set_image_backend('cv2')
# Use the CIFAR-10 dataset
train_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='train', transform=train_tfm)
val_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='test', transform=test_tfm)
print("train_dataset: %d" % len(train_dataset))
print("val_dataset: %d" % len(val_dataset))
train_dataset: 50000
val_dataset: 10000
In [5]
batch_size=64
In [6]
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=4)

2.3 Building the Model

2.3.1 Label Smoothing

In [7]
class LabelSmoothingCrossEntropy(nn.Layer):
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, pred, target):
        confidence = 1. - self.smoothing
        log_probs = F.log_softmax(pred, axis=-1)
        idx = paddle.stack([paddle.arange(log_probs.shape[0]), target], axis=1)
        nll_loss = paddle.gather_nd(-log_probs, index=idx)
        smooth_loss = paddle.mean(-log_probs, axis=-1)
        loss = confidence * nll_loss + self.smoothing * smooth_loss
        return loss.mean()
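
A quick sanity check (added, not part of the original notebook): with smoothing=0.0 the loss reduces to the standard cross entropy.

In [ ]
logits = paddle.randn((4, 10))
target = paddle.to_tensor([1, 3, 5, 7])
print(float(LabelSmoothingCrossEntropy(smoothing=0.0)(logits, target)))
print(float(F.cross_entropy(logits, target)))  # the two values should agree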

2.3.2 DropPath

In [8]
def drop_path(x, drop_prob=0.0, training=False):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = paddle.to_tensor(1 - drop_prob)
    shape = (paddle.shape(x)[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
    random_tensor = paddle.floor(random_tensor)  # binarize
    output = x.divide(keep_prob) * random_tensor
    return output


class DropPath(nn.Layer):
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
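
A small demonstration (added): in training mode drop_path zeroes entire samples and rescales the survivors by 1/keep_prob so the expectation is unchanged; in eval mode it is the identity.

In [ ]
x = paddle.ones((8, 3))
print(drop_path(x, drop_prob=0.5, training=True))                 # rows are all 0 or all 2.0
print(bool((drop_path(x, drop_prob=0.5, training=False) == x).all()))  # True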

2.3.3 Building the RegionViT Model

In [9]
class LayerNorm2D(nn.Layer):
    def __init__(self, channels, eps=1e-5, elementwise_affine=True):
        super().__init__()
        self.channels = channels
        self.eps = paddle.to_tensor(eps)
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = self.create_parameter(shape=(1, channels, 1, 1), default_initializer=nn.initializer.Constant(1.0))
            self.bias = self.create_parameter(shape=(1, channels, 1, 1), default_initializer=nn.initializer.Constant(0.0))
        else:
            self.register_buffer('weight', None)
            self.register_buffer('bias', None)

    def forward(self, input):
        mean = input.mean(1, keepdim=True)
        std = paddle.sqrt(input.var(1, unbiased=False, keepdim=True) + self.eps)
        out = (input - mean) / std
        if self.elementwise_affine:
            out = out * self.weight + self.bias
        return out
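
LayerNorm2D normalizes over the channel axis of an NCHW tensor. A quick equivalence check (added) against nn.LayerNorm applied in channels-last layout:

In [ ]
x = paddle.randn((2, 16, 8, 8))
a = LayerNorm2D(16)(x)
b = nn.LayerNorm(16)(x.transpose([0, 2, 3, 1])).transpose([0, 3, 1, 2])
print(float((a - b).abs().max()))  # should be ~1e-6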
In [10]
class AttentionWithRelPos(nn.Layer):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.,
                 attn_map_dim=None, num_cls_tokens=1):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.num_cls_tokens = num_cls_tokens
        if attn_map_dim is not None:
            one_dim = attn_map_dim[0]
            rel_pos_dim = (2 * one_dim - 1)
            self.rel_pos = self.create_parameter(shape=(num_heads, rel_pos_dim ** 2), default_initializer=nn.initializer.Constant(0.0))
            tmp = paddle.arange(rel_pos_dim ** 2).reshape((rel_pos_dim, rel_pos_dim))
            out = []
            offset_x = offset_y = one_dim // 2
            for y in range(one_dim):
                for x in range(one_dim):
                    for dy in range(one_dim):
                        for dx in range(one_dim):
                            out.append(tmp[dy - y + offset_y, dx - x + offset_x])
            self.rel_pos_index = paddle.to_tensor(out, dtype=paddle.int32)
            tn = nn.initializer.TruncatedNormal(std=.02)
            tn(self.rel_pos)
        else:
            self.rel_pos = None

    def forward(self, x, patch_attn=False, mask=None):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape((B, N, 3, self.num_heads, C // self.num_heads)).transpose([2, 0, 3, 1, 4])
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose([0, 1, 3, 2])) * self.scale
        if self.rel_pos is not None and patch_attn:
            # add the relative position bias only to the local-token block (cls tokens excluded)
            rel_pos = self.rel_pos[:, self.rel_pos_index].reshape((self.num_heads, N - self.num_cls_tokens, N - self.num_cls_tokens))
            attn[:, :, self.num_cls_tokens:, self.num_cls_tokens:] = attn[:, :, self.num_cls_tokens:, self.num_cls_tokens:] + rel_pos
        if mask is not None:
            # mask is (B*H_s*W_s) x (ksks) x (ksks); expand it over the heads
            mask = mask.unsqueeze(1).expand((-1, self.num_heads, -1, -1))
            attn = attn.masked_fill(mask == 0, paddle.finfo(attn.dtype).min)

        attn = F.softmax(attn, axis=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose([0, 2, 1, 3]).reshape((B, N, C))
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
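
A shape check (added): one regional token plus a 7x7 window of local tokens per window; with patch_attn=True the relative-position bias is added only to the 49x49 local-to-local block, consistent with the formula in Section 1.

In [ ]
attn = AttentionWithRelPos(dim=64, num_heads=4, attn_map_dim=(7, 7), num_cls_tokens=1)
tokens = paddle.randn((3, 1 + 49, 64))       # (num_windows, 1 + ws*ws, C)
print(attn(tokens, patch_attn=True).shape)   # [3, 50, 64]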
In [11]
def to_2tuple(x):
    return (x, x)


class PatchEmbed(nn.Layer):
    """ Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, patch_conv_type='linear'):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        if patch_conv_type == '3conv':
            if patch_size[0] == 4:
                tmp = [
                    nn.Conv2D(in_chans, embed_dim // 4, kernel_size=3, stride=2, padding=1),
                    LayerNorm2D(embed_dim // 4),
                    nn.GELU(),
                    nn.Conv2D(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=2, padding=1),
                    LayerNorm2D(embed_dim // 2),
                    nn.GELU(),
                    nn.Conv2D(embed_dim // 2, embed_dim, kernel_size=3, stride=1, padding=1),
                ]
            else:
                raise ValueError(f"Unknown patch size {patch_size[0]}")
            self.proj = nn.Sequential(*tmp)
        else:
            if patch_conv_type == '1conv':
                kernel_size = (2 * patch_size[0], 2 * patch_size[1])
                stride = (patch_size[0], patch_size[1])
                padding = (patch_size[0] - 1, patch_size[1] - 1)
            else:
                kernel_size = patch_size
                stride = patch_size
                padding = 0

            self.proj = nn.Conv2D(in_chans, embed_dim, kernel_size=kernel_size,
                                  stride=stride, padding=padding)

    def forward(self, x, extra_padding=False):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        # assert H == self.img_size[0] and W == self.img_size[1], \
        #     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        if extra_padding and (H % self.patch_size[0] != 0 or W % self.patch_size[1] != 0):
            p_l = (self.patch_size[1] - W % self.patch_size[1]) // 2
            p_r = (self.patch_size[1] - W % self.patch_size[1]) - p_l
            p_t = (self.patch_size[0] - H % self.patch_size[0]) // 2
            p_b = (self.patch_size[0] - H % self.patch_size[0]) - p_t
            x = F.pad(x, (p_l, p_r, p_t, p_b))
        x = self.proj(x)
        return x
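
A shape check (added): with patch_size=4, a 224x224 image yields the 56x56 grid of local tokens, while patch_size=4*7=28 yields the matching 8x8 grid of regional tokens.

In [ ]
pe_local = PatchEmbed(img_size=224, patch_size=4, embed_dim=64, patch_conv_type='3conv')
pe_region = PatchEmbed(img_size=224, patch_size=28, embed_dim=64, patch_conv_type='linear')
x = paddle.randn((1, 3, 224, 224))
print(pe_local(x).shape, pe_region(x).shape)  # [1, 64, 56, 56] [1, 64, 8, 8]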
In [12]
class Mlp(nn.Layer):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)

        self.fc1 = nn.Linear(in_features, hidden_features, bias_attr=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2 = nn.Linear(hidden_features, out_features, bias_attr=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
In [13]
class R2LAttentionPlusFFN(nn.Layer):

    def __init__(self, input_channels, output_channels, kernel_size, num_heads, mlp_ratio=1., qkv_bias=False, qk_scale=None,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, drop_path=0., attn_drop=0., drop=0.,
                 cls_attn=True):
        super().__init__()
        if not isinstance(kernel_size, (tuple, list)):
            kernel_size = [(kernel_size, kernel_size), (kernel_size, kernel_size), 0]
        self.kernel_size = kernel_size
        if cls_attn:
            self.norm0 = norm_layer(input_channels)
        else:
            self.norm0 = None

        self.norm1 = norm_layer(input_channels)
        self.attn = AttentionWithRelPos(
            input_channels, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
            attn_map_dim=(kernel_size[0][0], kernel_size[0][1]), num_cls_tokens=1)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(input_channels)
        self.mlp = Mlp(in_features=input_channels, hidden_features=int(output_channels * mlp_ratio), out_features=output_channels, act_layer=act_layer, drop=drop)

        self.expand = nn.Sequential(
            norm_layer(input_channels),
            act_layer(),
            nn.Linear(input_channels, output_channels)
        ) if input_channels != output_channels else None

        self.output_channels = output_channels
        self.input_channels = input_channels

    def forward(self, xs):
        out, B, H, W, mask = xs
        cls_tokens = out[:, 0:1, ...]

        C = cls_tokens.shape[-1]
        cls_tokens = cls_tokens.reshape((B, -1, C))  # (B)x(H/s*W/s)xC

        if self.norm0 is not None:
            # regional self-attention (RSA) over all regional tokens
            cls_tokens = cls_tokens + self.drop_path(self.attn(self.norm0(cls_tokens)))  # (B)x(H/s*W/s)xC

        cls_tokens = cls_tokens.reshape((-1, 1, C))  # (B*H/s*W/s)x1xC

        out = paddle.concat((cls_tokens, out[:, 1:, ...]), axis=1)
        tmp = out

        # local self-attention (LSA) within each window, regional token included
        tmp = tmp + self.drop_path(self.attn(self.norm1(tmp), patch_attn=True, mask=mask))
        identity = self.expand(tmp) if self.expand is not None else tmp
        tmp = identity + self.drop_path(self.mlp(self.norm2(tmp)))
        return tmp
In [14]
class Projection(nn.Layer):
    def __init__(self, input_channels, output_channels, act_layer, mode='sc'):
        super().__init__()
        tmp = []
        if 'c' in mode:
            ks = 2 if 's' in mode else 1
            if ks == 2:
                stride = ks
                ks = ks + 1
                padding = ks // 2
            else:
                stride = ks
                padding = 0

            if input_channels == output_channels and ks == 1:
                tmp.append(nn.Identity())
            else:
                tmp.extend([
                    LayerNorm2D(input_channels),
                    act_layer(),
                ])
                tmp.append(nn.Conv2D(in_channels=input_channels, out_channels=output_channels, kernel_size=ks, stride=stride, padding=padding, groups=input_channels))

        self.proj = nn.Sequential(*tmp)
        self.proj_cls = self.proj

    def forward(self, xs):
        cls_tokens, patch_tokens = xs  # x: BxCxHxW
        cls_tokens = self.proj_cls(cls_tokens)
        patch_tokens = self.proj(patch_tokens)
        return cls_tokens, patch_tokens
In [15]
def convert_to_flatten_layout(cls_tokens, patch_tokens, ws):
    """
    Convert the token layout into a flattened form; this speeds up the model.

    Also handles the case where the sizes of regional tokens and local tokens are inconsistent.
    """
    # padding if needed; all padding happens at the bottom and right.
    B, C, H, W = patch_tokens.shape
    _, _, H_ks, W_ks = cls_tokens.shape
    need_mask = False
    p_l, p_r, p_t, p_b = 0, 0, 0, 0
    if H % (H_ks * ws) != 0 or W % (W_ks * ws) != 0:
        p_l, p_r = 0, W_ks * ws - W
        p_t, p_b = 0, H_ks * ws - H
        patch_tokens = F.pad(patch_tokens, (p_l, p_r, p_t, p_b))
        need_mask = True

    B, C, H, W = patch_tokens.shape
    kernel_size = [H // H_ks, W // W_ks]
    tmp = F.unfold(patch_tokens, kernel_sizes=kernel_size, strides=kernel_size, paddings=[0, 0])  # Bx(Cxksxks)x(H/sxW/s)
    patch_tokens = tmp.transpose([0, 2, 1]).reshape((-1, C, kernel_size[0] * kernel_size[1])).transpose([0, 2, 1])  # (BxH/sxW/s)x(ksxks)xC

    if need_mask:
        BH_sW_s, ksks, C = patch_tokens.shape
        H_s, W_s = H // ws, W // ws
        mask = paddle.ones((BH_sW_s // B, 1 + ksks, 1 + ksks), dtype='float32')
        right = paddle.zeros((1 + ksks, 1 + ksks), dtype='float32')
        tmp = paddle.zeros((ws, ws), dtype='float32')
        tmp[0:(ws - p_r), 0:(ws - p_r)] = 1.
        tmp = tmp.tile([ws, ws])  # Paddle's equivalent of torch's Tensor.repeat
        right[1:, 1:] = tmp
        right[0, 0] = 1
        right[0, 1:] = paddle.to_tensor([1.] * (ws - p_r) + [0.] * p_r).tile([ws])
        right[1:, 0] = paddle.to_tensor([1.] * (ws - p_r) + [0.] * p_r).tile([ws])
        bottom = paddle.zeros_like(right)
        bottom[0:ws * (ws - p_b) + 1, 0:ws * (ws - p_b) + 1] = 1.
        bottom_right = copy.deepcopy(right)
        bottom_right[0:ws * (ws - p_b) + 1, 0:ws * (ws - p_b) + 1] = 1.

        mask[W_s - 1:(H_s - 1) * W_s:W_s, ...] = right
        mask[(H_s - 1) * W_s:, ...] = bottom
        mask[-1, ...] = bottom_right
        mask = mask.tile([B, 1, 1])
    else:
        mask = None

    cls_tokens = cls_tokens.flatten(2).transpose([0, 2, 1])  # (B)x(H/sxW/s)xC
    cls_tokens = cls_tokens.reshape((-1, 1, cls_tokens.shape[-1]))  # (BxH/sxW/s)x1xC

    out = paddle.concat((cls_tokens, patch_tokens), axis=1)
    return out, mask, p_l, p_r, p_t, p_b, B, C, H, W


def convert_to_spatial_layout(out, output_channels, B, H, W, kernel_size, mask, p_l, p_r, p_t, p_b):
    """
    Convert the token layout from flattened back to 2-D; used to downsample the spatial dimension.
    """
    cls_tokens = out[:, 0:1, ...]
    patch_tokens = out[:, 1:, ...]
    # cls_tokens: (BxH/sxW/s)x(1)xC, patch_tokens: (BxH/sxW/s)x(ksxks)xC
    C = output_channels
    kernel_size = kernel_size[0]
    H_ks = H // kernel_size[0]
    W_ks = W // kernel_size[1]
    # reorganize data back to cls_tokens: BxCxH/sxW/s, patch_tokens: BxCxHxW
    cls_tokens = cls_tokens.reshape((B, -1, C)).transpose([0, 2, 1]).reshape((B, C, H_ks, W_ks))
    patch_tokens = patch_tokens.transpose([0, 2, 1]).reshape((B, -1, kernel_size[0] * kernel_size[1] * C)).transpose([0, 2, 1])
    patch_tokens = F.fold(patch_tokens, [H, W], kernel_sizes=kernel_size, strides=kernel_size, paddings=[0, 0])
    if mask is not None:
        # strip the padding added in convert_to_flatten_layout
        if p_b > 0:
            patch_tokens = patch_tokens[:, :, :-p_b, :]
        if p_r > 0:
            patch_tokens = patch_tokens[:, :, :, :-p_r]
    return cls_tokens, patch_tokens
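
A round-trip check (added): when the token grids divide evenly no mask is needed, and flattening into windows followed by folding back reproduces the input exactly. The shapes below are illustrative.

In [ ]
cls_tokens = paddle.randn((2, 32, 4, 4))      # B x C x H/ws x W/ws
patch_tokens = paddle.randn((2, 32, 28, 28))  # B x C x H x W, with ws = 7
out, mask, p_l, p_r, p_t, p_b, B, C, H, W = convert_to_flatten_layout(cls_tokens, patch_tokens, ws=7)
print(out.shape)  # [32, 50, 32]: (B*H/ws*W/ws) x (1 + ws*ws) x C
cls2, patch2 = convert_to_spatial_layout(out, 32, B, H, W, [[7, 7], [7, 7], 0], mask, p_l, p_r, p_t, p_b)
print(float((patch2 - patch_tokens).abs().max()))  # ~0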
In [16]
class ConvAttBlock(nn.Layer):
    def __init__(self, input_channels, output_channels, kernel_size, num_blocks, num_heads, mlp_ratio=1., qkv_bias=False, qk_scale=None, pool='sc',
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, drop_path_rate=(0.,), attn_drop_rate=0., drop_rate=0.,
                 cls_attn=True, peg=False):
        super().__init__()
        tmp = []
        if pool:
            tmp.append(Projection(input_channels, output_channels, act_layer=act_layer, mode=pool))
        for i in range(num_blocks):
            kernel_size_ = kernel_size
            tmp.append(R2LAttentionPlusFFN(output_channels, output_channels, kernel_size_, num_heads, mlp_ratio, qkv_bias, qk_scale,
                                           act_layer=act_layer, norm_layer=norm_layer, drop_path=drop_path_rate[i], attn_drop=attn_drop_rate, drop=drop_rate,
                                           cls_attn=cls_attn))

        self.block = nn.LayerList(tmp)
        self.output_channels = output_channels
        self.ws = kernel_size
        if not isinstance(kernel_size, (tuple, list)):
            kernel_size = [[kernel_size, kernel_size], [kernel_size, kernel_size], 0]
        self.kernel_size = kernel_size

        self.peg = nn.Conv2D(output_channels, output_channels, kernel_size=3, padding=1, groups=output_channels, bias_attr=False) if peg else None

    def forward(self, xs):
        cls_tokens, patch_tokens = xs
        cls_tokens, patch_tokens = self.block[0]((cls_tokens, patch_tokens))
        out, mask, p_l, p_r, p_t, p_b, B, C, H, W = convert_to_flatten_layout(cls_tokens, patch_tokens, self.ws)
        for i in range(1, len(self.block)):
            blk = self.block[i]
            out = blk((out, B, H, W, mask))
            if self.peg is not None and i == 1:
                cls_tokens, patch_tokens = convert_to_spatial_layout(out, self.output_channels, B, H, W, self.kernel_size, mask, p_l, p_r, p_t, p_b)
                cls_tokens = cls_tokens + self.peg(cls_tokens)
                patch_tokens = patch_tokens + self.peg(patch_tokens)
                out, mask, p_l, p_r, p_t, p_b, B, C, H, W = convert_to_flatten_layout(cls_tokens, patch_tokens, self.ws)

        cls_tokens, patch_tokens = convert_to_spatial_layout(out, self.output_channels, B, H, W, self.kernel_size, mask, p_l, p_r, p_t, p_b)
        return cls_tokens, patch_tokens
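
A forward check of one stage (added): pool='c' keeps the spatial size (no stride-2 downsampling), so both token maps come back with their input shapes.

In [ ]
blk = ConvAttBlock(64, 64, kernel_size=7, num_blocks=1, num_heads=4, pool='c', drop_path_rate=[0.])
cls_tokens = paddle.randn((1, 64, 4, 4))
patch_tokens = paddle.randn((1, 64, 28, 28))
cls_out, patch_out = blk((cls_tokens, patch_tokens))
print(cls_out.shape, patch_out.shape)  # [1, 64, 4, 4] [1, 64, 28, 28]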
In [17]
class RegionViT(nn.Layer):
    """
    Note:
        The variable naming mapping between the code and the paper:
        - cls_tokens -> regional tokens
        - patch_tokens -> local tokens
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=(768,), depth=(12,),
                 num_heads=(12,), mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=partial(nn.LayerNorm, epsilon=1e-6),
                 # regionvit parameters
                 kernel_sizes=None, downsampling=None,
                 patch_conv_type='3conv',
                 computed_cls_token=True, peg=False,
                 det_norm=False):
        super().__init__()
        self.num_classes = num_classes
        self.kernel_sizes = kernel_sizes
        self.num_features = embed_dim[-1]  # num_features for consistency with other models
        self.embed_dim = embed_dim
        self.patch_size = patch_size
        self.img_size = img_size
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim[0],
            patch_conv_type=patch_conv_type)
        if not isinstance(mlp_ratio, (list, tuple)):
            mlp_ratio = [mlp_ratio] * len(depth)

        self.computed_cls_token = computed_cls_token
        self.cls_token = PatchEmbed(
            img_size=img_size, patch_size=patch_size * kernel_sizes[0], in_chans=in_chans, embed_dim=embed_dim[0],
            patch_conv_type='linear'
        )
        self.pos_drop = nn.Dropout(p=drop_rate)
        total_depth = sum(depth)
        dpr = [x.item() for x in paddle.linspace(0, drop_path_rate, total_depth)]  # stochastic depth decay rule
        dpr_ptr = 0
        self.layers = nn.LayerList()
        for i in range(len(embed_dim) - 1):
            curr_depth = depth[i]
            dpr_ = dpr[dpr_ptr: dpr_ptr + curr_depth]

            self.layers.append(
                ConvAttBlock(embed_dim[i], embed_dim[i + 1], kernel_size=kernel_sizes[i], num_blocks=depth[i], drop_path_rate=dpr_,
                             num_heads=num_heads[i], mlp_ratio=mlp_ratio[i], qkv_bias=qkv_bias, qk_scale=qk_scale,
                             pool=downsampling[i], norm_layer=norm_layer, attn_drop_rate=attn_drop_rate, drop_rate=drop_rate,
                             cls_attn=True, peg=peg)
            )
            dpr_ptr += curr_depth
        self.norm = norm_layer(embed_dim[-1])

        # Classifier head
        self.head = nn.Linear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
        if not computed_cls_token:
            tn = nn.initializer.TruncatedNormal(std=.02)
            tn(self.cls_token)

        self.det_norm = det_norm
        if self.det_norm:
            # add a norm layer for the outputs at each stage, for detection
            for i in range(4):
                layer = LayerNorm2D(embed_dim[1 + i])
                layer_name = f'norm{i}'
                self.add_sublayer(layer_name, layer)  # Paddle's equivalent of torch's add_module

        self.apply(self._init_weights)

    def _init_weights(self, m):
        tn = nn.initializer.TruncatedNormal(std=.02)
        ones = nn.initializer.Constant(1.0)
        zeros = nn.initializer.Constant(0.0)
        if isinstance(m, nn.Linear):
            tn(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                zeros(m.bias)
        elif isinstance(m, nn.LayerNorm):
            zeros(m.bias)
            ones(m.weight)

    def forward_features(self, x, detection=False):
        o_x = x
        x = self.patch_embed(x)
        # regional tokens are computed from the same image with a larger patch size
        cls_tokens = self.cls_token(o_x, extra_padding=True)
        x = self.pos_drop(x)  # N C H W
        tmp_out = []
        for idx, layer in enumerate(self.layers):
            cls_tokens, x = layer((cls_tokens, x))
            if self.det_norm:
                norm_layer = getattr(self, f'norm{idx}')
                x = norm_layer(x)
            tmp_out.append(x)
        if detection:
            return tmp_out

        N, C, H, W = cls_tokens.shape
        cls_tokens = cls_tokens.reshape((N, C, -1)).transpose([0, 2, 1])
        cls_tokens = self.norm(cls_tokens)
        out = paddle.mean(cls_tokens, axis=1)
        return out

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x
In [18]
_model_cfg = {
    'tiny': {
        'img_size': 224, 'patch_conv_type': '3conv', 'patch_size': 4,
        'embed_dim': [64, 64, 128, 256, 512], 'num_heads': [2, 4, 8, 16],
        'mlp_ratio': 4., 'depth': [2, 2, 8, 2],
        'kernel_sizes': [7, 7, 7, 7],  # region grids: 8x8, 4x4, 2x2, 1x1
        'downsampling': ['c', 'sc', 'sc', 'sc'],
    },
    'small': {
        'img_size': 224, 'patch_conv_type': '3conv', 'patch_size': 4,
        'embed_dim': [96, 96, 192, 384, 768], 'num_heads': [3, 6, 12, 24],
        'mlp_ratio': 4., 'depth': [2, 2, 8, 2],
        'kernel_sizes': [7, 7, 7, 7],  # region grids: 8x8, 4x4, 2x2, 1x1
        'downsampling': ['c', 'sc', 'sc', 'sc'],
    },
    'medium': {
        'img_size': 224, 'patch_conv_type': '1conv', 'patch_size': 4,
        'embed_dim': [96] + [96 * (2 ** i) for i in range(4)], 'num_heads': [3, 6, 12, 24],
        'mlp_ratio': 4., 'depth': [2, 2, 14, 2],
        'kernel_sizes': [7, 7, 7, 7],  # region grids: 8x8, 4x4, 2x2, 1x1
        'downsampling': ['c', 'sc', 'sc', 'sc'],
    },
    'base': {
        'img_size': 224, 'patch_conv_type': '1conv', 'patch_size': 4,
        'embed_dim': [128, 128, 256, 512, 1024], 'num_heads': [4, 8, 16, 32],
        'mlp_ratio': 4., 'depth': [2, 2, 14, 2],
        'kernel_sizes': [7, 7, 7, 7],  # region grids: 8x8, 4x4, 2x2, 1x1
        'downsampling': ['c', 'sc', 'sc', 'sc'],
    },
    'small_w14': {
        'img_size': 224, 'patch_conv_type': '3conv', 'patch_size': 4,
        'embed_dim': [96, 96, 192, 384, 768], 'num_heads': [3, 6, 12, 24],
        'mlp_ratio': 4., 'depth': [2, 2, 8, 2],
        'kernel_sizes': [14, 14, 14, 14],
        'downsampling': ['c', 'sc', 'sc', 'sc'],
    },
    'small_w14_peg': {
        'img_size': 224, 'patch_conv_type': '3conv', 'patch_size': 4,
        'embed_dim': [96, 96, 192, 384, 768], 'num_heads': [3, 6, 12, 24],
        'mlp_ratio': 4., 'depth': [2, 2, 8, 2],
        'kernel_sizes': [14, 14, 14, 14],
        'downsampling': ['c', 'sc', 'sc', 'sc'], 'peg': True,
    },
    'base_w14': {
        'img_size': 224, 'patch_conv_type': '1conv', 'patch_size': 4,
        'embed_dim': [128, 128, 256, 512, 1024], 'num_heads': [4, 8, 16, 32],
        'mlp_ratio': 4., 'depth': [2, 2, 14, 2],
        'kernel_sizes': [14, 14, 14, 14],
        'downsampling': ['c', 'sc', 'sc', 'sc'],
    },
    'base_w14_peg': {
        'img_size': 224, 'patch_conv_type': '1conv', 'patch_size': 4,
        'embed_dim': [128, 128, 256, 512, 1024], 'num_heads': [4, 8, 16, 32],
        'mlp_ratio': 4., 'depth': [2, 2, 14, 2],
        'kernel_sizes': [14, 14, 14, 14],
        'downsampling': ['c', 'sc', 'sc', 'sc'], 'peg': True,
    },
}
In [19]
num_classes = 10


def regionvit_tiny_224():
    model_cfg = _model_cfg['tiny']
    model = RegionViT(**model_cfg, num_classes=num_classes)
    return model


def regionvit_small_224():
    model_cfg = _model_cfg['small']
    model = RegionViT(**model_cfg, num_classes=num_classes)
    return model


def regionvit_small_w14_224():
    model_cfg = _model_cfg['small_w14']
    model = RegionViT(**model_cfg, num_classes=num_classes)
    return model


def regionvit_small_w14_peg_224():
    model_cfg = _model_cfg['small_w14_peg']
    model = RegionViT(**model_cfg, num_classes=num_classes)
    return model


def regionvit_medium_224():
    model_cfg = _model_cfg['medium']
    model = RegionViT(**model_cfg, num_classes=num_classes)
    return model


def regionvit_base_224():
    model_cfg = _model_cfg['base']
    model = RegionViT(**model_cfg, num_classes=num_classes)
    return model


def regionvit_base_w14_224():
    model_cfg = _model_cfg['base_w14']
    model = RegionViT(**model_cfg, num_classes=num_classes)
    return model


def regionvit_base_w14_peg_224():
    model_cfg = _model_cfg['base_w14_peg']
    model = RegionViT(**model_cfg, num_classes=num_classes)
    return model

2.3.4 Model Parameters

In [ ]
model = regionvit_tiny_224()
paddle.summary(model, (1, 3, 224, 224))

[paddle.summary output for RegionViT-Tiny]

2.4 Training

In [22]
learning_rate = 0.0001
n_epochs = 100
paddle.seed(42)
np.random.seed(42)
In [ ]
work_path = 'work/model'

# RegionViT-Tiny
model = regionvit_tiny_224()

criterion = LabelSmoothingCrossEntropy()

scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=learning_rate, T_max=50000 // batch_size * n_epochs, verbose=False)
optimizer = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=scheduler, weight_decay=1e-5)

gate = 0.0
threshold = 0.0
best_acc = 0.0
val_acc = 0.0
loss_record = {'train': {'loss': [], 'iter': []}, 'val': {'loss': [], 'iter': []}}   # for recording loss
acc_record = {'train': {'acc': [], 'iter': []}, 'val': {'acc': [], 'iter': []}}      # for recording accuracy
loss_iter = 0
acc_iter = 0

for epoch in range(n_epochs):
    # ---------- Training ----------
    model.train()
    train_num = 0.0
    train_loss = 0.0

    val_num = 0.0
    val_loss = 0.0
    accuracy_manager = paddle.metric.Accuracy()
    val_accuracy_manager = paddle.metric.Accuracy()
    print("#===epoch: {}, lr={:.10f}===#".format(epoch, optimizer.get_lr()))
    for batch_id, data in enumerate(train_loader):
        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)

        logits = model(x_data)
        loss = criterion(logits, y_data)

        acc = accuracy_manager.compute(logits, labels)
        accuracy_manager.update(acc)
        if batch_id % 10 == 0:
            loss_record['train']['loss'].append(loss.numpy())
            loss_record['train']['iter'].append(loss_iter)
            loss_iter += 1

        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.clear_grad()

        train_loss += loss
        train_num += len(y_data)

    total_train_loss = (train_loss / train_num) * batch_size
    train_acc = accuracy_manager.accumulate()
    acc_record['train']['acc'].append(train_acc)
    acc_record['train']['iter'].append(acc_iter)
    acc_iter += 1
    # Print the information.
    print("#===epoch: {}, train loss is: {}, train acc is: {:2.2f}%===#".format(epoch, total_train_loss.numpy(), train_acc * 100))

    # ---------- Validation ----------
    model.eval()
    for batch_id, data in enumerate(val_loader):
        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)
        with paddle.no_grad():
            logits = model(x_data)

        loss = criterion(logits, y_data)

        acc = val_accuracy_manager.compute(logits, labels)
        val_accuracy_manager.update(acc)

        val_loss += loss
        val_num += len(y_data)

    total_val_loss = (val_loss / val_num) * batch_size
    loss_record['val']['loss'].append(total_val_loss.numpy())
    loss_record['val']['iter'].append(loss_iter)
    val_acc = val_accuracy_manager.accumulate()
    acc_record['val']['acc'].append(val_acc)
    acc_record['val']['iter'].append(acc_iter)
    print("#===epoch: {}, val loss is: {}, val acc is: {:2.2f}%===#".format(epoch, total_val_loss.numpy(), val_acc * 100))

    # ===================save====================
    if val_acc > best_acc:
        best_acc = val_acc
        paddle.save(model.state_dict(), os.path.join(work_path, 'best_model.pdparams'))
        paddle.save(optimizer.state_dict(), os.path.join(work_path, 'best_optimizer.pdopt'))

print(best_acc)
paddle.save(model.state_dict(), os.path.join(work_path, 'final_model.pdparams'))
paddle.save(optimizer.state_dict(), os.path.join(work_path, 'final_optimizer.pdopt'))

[Training log output]

2.5 Result Analysis

In [24]
def plot_learning_curve(record, title='loss', ylabel='CE Loss'):
''' Plot the learning curve of the model '''
    maxtrain = max(map(float, record['train'][title]))
    maxval = max(map(float, record['val'][title]))
    ymax = max(maxtrain, maxval) * 1.1
    mintrain = min(map(float, record['train'][title]))
    minval = min(map(float, record['val'][title]))
    ymin = min(mintrain, minval) * 0.9

    total_steps = len(record['train'][title])
    x_1 = list(map(int, record['train']['iter']))
    x_2 = list(map(int, record['val']['iter']))
    figure(figsize=(10, 6))
    plt.plot(x_1, record['train'][title], c='tab:red', label='train')
    plt.plot(x_2, record['val'][title], c='tab:cyan', label='val')
    plt.ylim(ymin, ymax)
    plt.xlabel('Training steps')
    plt.ylabel(ylabel)
    plt.title('Learning curve of {}'.format(title))
    plt.legend()
    plt.show()
In [25]
plot_learning_curve(loss_record, title='loss', ylabel='CE Loss')
<Figure size 1000x600 with 1 Axes>
In [26]
plot_learning_curve(acc_record, title='acc', ylabel='Accuracy')
<Figure size 1000x600 with 1 Axes>
In [27]
import time

work_path = 'work/model'
model = regionvit_tiny_224()
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
aa = time.time()
for batch_id, data in enumerate(val_loader):
    x_data, y_data = data
    labels = paddle.unsqueeze(y_data, axis=1)
    with paddle.no_grad():
        logits = model(x_data)
bb = time.time()
print("Throughput: {} images/s".format(int(len(val_dataset) // (bb - aa))))
Throughput: 678 images/s
In [28]
def get_cifar10_labels(labels):
    """Return the text labels of the CIFAR-10 dataset."""
    text_labels = [
        'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
        'horse', 'ship', 'truck']
    return [text_labels[int(i)] for i in labels]
In [29]
def show_images(imgs, num_rows, num_cols, pred=None, gt=None, scale=1.5):
    """Plot a list of images."""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if paddle.is_tensor(img):
            ax.imshow(img.numpy())
        else:
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if pred or gt:
            ax.set_title("pt: " + pred[i] + "\ngt: " + gt[i])
    return axes
In [30]
work_path = 'work/model'
X, y = next(iter(DataLoader(val_dataset, batch_size=18)))
model = regionvit_tiny_224()
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
logits = model(X)
y_pred = paddle.argmax(logits, -1)
X = paddle.transpose(X, [0, 2, 3, 1])
axes = show_images(X.reshape((18, 224, 224, 3)), 1, 18, pred=get_cifar10_labels(y_pred), gt=get_cifar10_labels(y))
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).  (printed 18 times, once per image)
<Figure size 2700x150 with 18 Axes>

Summary

This work proposes a coarse-to-fine, regional-to-local Transformer that enjoys both a global receptive field and locality, and is simple and efficient to implement.
