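Attention is widely regarded as the key to the success of vision Transformers, but this work questions whether it is necessary. ShiftViT replaces the attention layers with a Shift operation that has zero parameters and zero FLOPs. On classification, detection, and segmentation it matches or even outperforms Swin Transformer, which suggests that attention may not be the key to ViT's success and that future work should look more closely at the remaining parts of the architecture.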

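The attention mechanism is widely believed to be the key to the success of vision Transformers, because it provides a flexible and powerful way to model spatial relationships. But is attention really an indispensable part of ViT? Can it be replaced by something else? To demystify the role of attention, the authors simplify it to an extreme case: ZERO FLOP and ZERO parameter. Specifically, they revisit the Shift operation. It contains no parameters and no arithmetic computation; the only thing it does is exchange a small portion of the channels between neighboring features. Based on this simple operation they build a new backbone network, ShiftViT, in which the attention layers are replaced by Shift operations. Surprisingly, ShiftViT works quite well on several mainstream tasks, such as classification, detection, and segmentation. Its performance is on par with, or even better than, the strong baseline Swin Transformer. These results suggest that the attention mechanism may not be the decisive factor in ViT's success: it can even be replaced by a zero-parameter operation. Future work should pay more attention to the remaining parts of ViT.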
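To test whether the attention mechanism in Transformers is necessary, ShiftViT replaces it with a simple and efficient Shift operation, which moves part of the channels by one pixel along each of the four directions (left, right, up, down). This is very similar to S2MLP; the difference is that S2MLP shifts all channels, whereas ShiftViT shifts only a fraction of them: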
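$$
\begin{aligned}
\hat{z}[0{:}H,\ 1{:}W,\ 0{:}\gamma C] &\leftarrow z[0{:}H,\ 0{:}W-1,\ 0{:}\gamma C] \\
\hat{z}[0{:}H,\ 0{:}W-1,\ \gamma C{:}2\gamma C] &\leftarrow z[0{:}H,\ 1{:}W,\ \gamma C{:}2\gamma C] \\
\hat{z}[0{:}H-1,\ 0{:}W,\ 2\gamma C{:}3\gamma C] &\leftarrow z[1{:}H,\ 0{:}W,\ 2\gamma C{:}3\gamma C] \\
\hat{z}[1{:}H,\ 0{:}W,\ 3\gamma C{:}4\gamma C] &\leftarrow z[0{:}H-1,\ 0{:}W,\ 3\gamma C{:}4\gamma C] \\
\hat{z}[0{:}H,\ 0{:}W,\ 4\gamma C{:}C] &\leftarrow z[0{:}H,\ 0{:}W,\ 4\gamma C{:}C]
\end{aligned}
$$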
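As a minimal sketch of the five assignments above, the shift can be written in plain Python on a nested list laid out as H x W x C. The function name `shift_zero_fill` and the zero-filling of vacated border positions are assumptions made for illustration (zero-filling corresponds to starting from an all-zero output); the Paddle `Shift` layer later in this article uses a B x C x H x W layout and writes into the input tensor instead.

```python
def shift_zero_fill(z, n_div):
    """Shift four channel groups by one pixel, zero-filling vacated positions.

    z is a nested list indexed as z[h][w][c] (the H x W x C layout of the formula).
    """
    H, W, C = len(z), len(z[0]), len(z[0][0])
    g = C // n_div  # channels per shifted group, i.e. gamma * C
    # (dh, dw): out[h][w] reads z[h + dh][w + dw] for each of the four groups
    offsets = [(0, -1), (0, 1), (1, 0), (-1, 0)]
    out = [[[0.0] * C for _ in range(W)] for _ in range(H)]
    for h in range(H):
        for w in range(W):
            for c in range(C):
                group = c // g if g > 0 else 4
                dh, dw = offsets[group] if group < 4 else (0, 0)  # groups >= 4 are not shifted
                sh, sw = h + dh, w + dw
                if 0 <= sh < H and 0 <= sw < W:
                    out[h][w][c] = z[sh][sw][c]
    return out


# Tiny 2 x 3 feature map with 5 channels; n_div=5 gives one channel per group.
z = [[[100.0 * c + 10.0 * h + w + 1.0 for c in range(5)] for w in range(3)] for h in range(2)]
out = shift_zero_fill(z, n_div=5)
print(out[0][1][0] == z[0][0][0])  # group 0 reads its left neighbour
print(out[0][0][0])                # border position with no left neighbour
```

The loop has no learnable weights and performs no arithmetic on the feature values, which is the sense in which the operation is described as zero-parameter and zero-FLOP.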

!pip install einops-0.3.0-py3-none-any.whl
%matplotlib inline
import os

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure

import paddle
from paddle import nn
import paddle.nn.functional as F
import paddle.vision.transforms as transforms
from paddle.vision.datasets import Cifar10
from paddle.vision.transforms import Transpose
from paddle.io import Dataset, DataLoader

from einops.layers.paddle import Rearrange, Reduce
from einops import rearrange
train_tfm = transforms.Compose([
    transforms.Resize((230, 230)),
    transforms.ColorJitter(brightness=0.2,contrast=0.2, saturation=0.2),
    transforms.RandomResizedCrop(224, scale=(0.6, 1.0)),
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomRotation(20),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
test_tfm = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
paddle.vision.set_image_backend('cv2')
# Use the CIFAR-10 dataset
train_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='train', transform=train_tfm)
val_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='test', transform=test_tfm)
print("train_dataset: %d" % len(train_dataset))
print("val_dataset: %d" % len(val_dataset))
train_dataset: 50000
val_dataset: 10000
batch_size=128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=4)
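Training uses a label-smoothing cross-entropy loss, defined right after this point. As a rough illustration of what that loss computes for a single sample, here is a plain-Python sketch using the same formula (confidence times the negative log-likelihood of the target, plus the smoothing weight times the mean negative log-probability over all classes). The helper name `label_smoothing_ce` is hypothetical and not part of the original notebook.

```python
import math


def label_smoothing_ce(logits, target, smoothing=0.1):
    """Label-smoothing cross-entropy for one sample, given raw logits."""
    m = max(logits)
    # log-sum-exp with the maximum subtracted for numerical stability
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    log_probs = [v - log_z for v in logits]
    nll = -log_probs[target]                   # standard cross-entropy term
    smooth = -sum(log_probs) / len(log_probs)  # mean over all classes
    return (1.0 - smoothing) * nll + smoothing * smooth


confident = [10.0, 0.0, 0.0]
print(label_smoothing_ce(confident, target=0, smoothing=0.0))  # plain cross-entropy
print(label_smoothing_ce(confident, target=0, smoothing=0.1))  # larger: over-confidence is penalised
```

With `smoothing=0.0` the function reduces to ordinary cross-entropy, so the second value shows how much the smoothing term adds for a very confident prediction.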
class LabelSmoothingCrossEntropy(nn.Layer):
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, pred, target):
        confidence = 1. - self.smoothing
        log_probs = F.log_softmax(pred, axis=-1)
        idx = paddle.stack([paddle.arange(log_probs.shape[0]), target], axis=1)
        nll_loss = paddle.gather_nd(-log_probs, index=idx)
        smooth_loss = paddle.mean(-log_probs, axis=-1)
        loss = confidence * nll_loss + self.smoothing * smooth_loss
        return loss.mean()


def drop_path(x, drop_prob=0.0, training=False):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = paddle.to_tensor(1 - drop_prob)
    # One random value per sample, broadcast over all remaining axes
    shape = (paddle.shape(x)[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
    random_tensor = paddle.floor(random_tensor)  # binarize
    # Rescale kept samples by 1 / keep_prob so the expected output equals x
    output = x.divide(keep_prob) * random_tensor
    return output


class DropPath(nn.Layer):
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class MLP(nn.Layer):
    def __init__(self, in_features, hidden_features=None, out_features=None,act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2D(in_features, hidden_features, 1)
        self.act = act_layer()
        self.fc2 = nn.Conv2D(hidden_features, out_features, 1)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Shift(nn.Layer):
    def __init__(self, n_div):
        super().__init__()
        self.n_div = n_div

    def forward(self, x):
        B, C, H, W = x.shape
        g = C // self.n_div  # channels per shifted group
        # out = paddle.zeros_like(x)
        # The slices are written back into x itself, so the vacated border
        # row/column keeps its old values instead of being zero-filled.
        x[:, g * 0:g * 1, :, :-1] = x[:, g * 0:g * 1, :, 1:]  # shift left
        x[:, g * 1:g * 2, :, 1:] = x[:, g * 1:g * 2, :, :-1]  # shift right
        x[:, g * 2:g * 3, :-1, :] = x[:, g * 2:g * 3, 1:, :]  # shift up
        x[:, g * 3:g * 4, 1:, :] = x[:, g * 3:g * 4, :-1, :]  # shift down
        x[:, g * 4:, :, :] = x[:, g * 4:, :, :]  # no shift
        return x


class ShiftViTBlock(nn.Layer):
    def __init__(self, dim, n_div=12, mlp_ratio=4., drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                input_resolution=None):
        super().__init__()
        self.input_resolution = input_resolution
        self.mlp_ratio = mlp_ratio
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLP(in_features=dim,
                       hidden_features=mlp_hidden_dim,
                       act_layer=act_layer,
                       drop=drop)
        self.n_div = n_div
        self.shift = Shift(n_div)

    def forward(self, x):
        x = self.shift(x)
        shortcut = x
        # LayerNorm acts on the last axis, so move channels last and back again
        x = shortcut + self.drop_path(self.mlp(self.norm(x.transpose([0, 2, 3, 1])).transpose([0, 3, 1, 2])))
        return x


class BasicLayer(nn.Layer):
    def __init__(self, dim, input_resolution, depth, n_div=12, mlp_ratio=4., drop=0., drop_path=None, norm_layer=None, downsample=True,
                act_layer=nn.GELU):
        super(BasicLayer, self).__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        # build blocks
        self.blocks = nn.LayerList([
            ShiftViTBlock(dim=dim,
                          n_div=n_div,
                          mlp_ratio=mlp_ratio,
                          drop=drop,
                          drop_path=drop_path[i],
                          norm_layer=norm_layer,
                          act_layer=act_layer,
                          input_resolution=input_resolution)
            for i in range(depth)
        ])
        # patch merging layer
        if downsample:
            self.downsample = nn.Sequential(
                nn.GroupNorm(num_groups=1, num_channels=dim),
                nn.Conv2D(dim, dim * 2, kernel_size=2, stride=2, bias_attr=False)
            )
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x


class ShiftViT(nn.Layer):
    def __init__(self,n_div=12, img_size=224, patch_size=4, in_chans=3, num_classes=1000, embed_dim=96, depths=(2, 2, 6, 2), mlp_ratio=2.,
                drop_rate=0., drop_path_rate=0.1, patch_norm=True, **kwargs):
        super().__init__()
        norm_layer = nn.LayerNorm
        act_layer = nn.GELU
        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio
        # split image into non-overlapping patches
        self.patch_embed = nn.Sequential(
            nn.Conv2D(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size),
            Rearrange('b c h w->b h w c'),
            nn.LayerNorm(embed_dim) if self.patch_norm else nn.Identity(),
            Rearrange('b h w c->b c h w')
        )
        # num_patches = self.patch_embed.num_patches
        patches_resolution = [img_size // patch_size, img_size // patch_size]
        self.patches_resolution = patches_resolution
        self.pos_drop = nn.Dropout(p=drop_rate)
        # stochastic depth decay rule
        dpr = [x.item() for x in paddle.linspace(0, drop_path_rate, sum(depths))]
        # build layers
        self.layers = nn.LayerList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               n_div=n_div,
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               mlp_ratio=self.mlp_ratio,
                               drop=drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               downsample=(i_layer < self.num_layers - 1),
                               act_layer=act_layer)
            self.layers.append(layer)
        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool2D(1)
        self.head = nn.Linear(self.num_features, num_classes) \
            if num_classes > 0 else nn.Identity()
        self.apply(self._init_weights)

    def _init_weights(self, m):
        tn = nn.initializer.TruncatedNormal(std=.02)
        zeros = nn.initializer.Constant(0.)
        ones = nn.initializer.Constant(1.)
        if isinstance(m, nn.Linear):
            tn(m.weight)
            if m.bias is not None:
                zeros(m.bias)
        elif isinstance(m, (nn.Conv1D, nn.Conv2D)):
            tn(m.weight)
            if m.bias is not None:
                zeros(m.bias)
        elif isinstance(m, (nn.LayerNorm, nn.GroupNorm)):
            zeros(m.bias)
            ones(m.weight)

    def forward_features(self, x):
        x = self.patch_embed(x)
        x = self.pos_drop(x)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x.transpose([0, 2, 3, 1])).transpose([0, 3, 1, 2])
        x = self.avgpool(x)
        x = paddle.flatten(x, 1)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x


# Shift-T
model = ShiftViT(n_div=12, embed_dim=96, depths=(6, 8, 18, 6), num_classes=10)
paddle.summary(model, (1, 3, 224, 224))
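To make the tensor shapes in the model above concrete, the following sketch lists the spatial resolution, channel width, and shift group size at each stage. It only mirrors the arithmetic in `ShiftViT.__init__` and `Shift.forward` (patch size 4, channels doubling per stage, `g = C // n_div`); the helper name `stage_shapes` is hypothetical.

```python
def stage_shapes(img_size=224, patch_size=4, embed_dim=96, num_stages=4, n_div=12):
    """Per-stage resolution, channel count, and number of shifted channels."""
    rows = []
    base_res = img_size // patch_size  # resolution after patch embedding
    for i in range(num_stages):
        channels = embed_dim * 2 ** i  # width doubles at every downsample
        g = channels // n_div          # channels in each of the 4 shifted groups
        rows.append({
            "stage": i,
            "resolution": base_res // 2 ** i,
            "channels": channels,
            "group": g,
            "shifted": 4 * g,
        })
    return rows


for row in stage_shapes():
    print(row)
```

With `n_div=12`, four groups of `C / 12` channels are shifted, so one third of the channels move and the remaining two thirds pass through unchanged at every stage.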

# Shift-S
model = ShiftViT(n_div=12, embed_dim=96, depths=(10, 18, 36, 10), num_classes=10)
paddle.summary(model, (1, 3, 224, 224))

# Shift-B
model = ShiftViT(n_div=16, embed_dim=128, depths=(10, 18, 36, 10), num_classes=10)
paddle.summary(model, (1, 3, 224, 224))

# Shift-ours
model = ShiftViT(n_div=12, embed_dim=96, depths=(3, 4, 9, 3), num_classes=10)
paddle.summary(model, (1, 3, 224, 224))
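For the (3, 4, 9, 3) configuration, the stochastic-depth rule in `ShiftViT.__init__` spreads drop-path rates linearly from 0 to `drop_path_rate` over all 19 blocks and then slices them per stage. The sketch below reproduces that schedule in plain Python and also applies the drop-path rule to a single scalar, to show why kept samples are divided by `keep_prob`. Both helper names are hypothetical, and the explicit uniform draw `u` stands in for `paddle.rand`.

```python
import math


def drop_path_rates(depths, drop_path_rate=0.1):
    """Linearly increasing drop-path rates, split into one list per stage."""
    total = sum(depths)
    step = drop_path_rate / (total - 1) if total > 1 else 0.0
    rates = [step * i for i in range(total)]
    per_stage, start = [], 0
    for depth in depths:
        per_stage.append(rates[start:start + depth])
        start += depth
    return per_stage


def drop_path_scalar(x, drop_prob, u):
    """Drop-path rule for one sample value x, given a uniform draw u in [0, 1)."""
    keep_prob = 1.0 - drop_prob
    mask = math.floor(keep_prob + u)  # 1 with probability keep_prob, else 0
    return x / keep_prob * mask       # rescale so the expectation stays x


for i, stage in enumerate(drop_path_rates((3, 4, 9, 3))):
    print(i, [round(r, 4) for r in stage])
```

Early blocks are almost never dropped while the last block is dropped with probability `drop_path_rate`, and the `1 / keep_prob` rescaling keeps the expected residual branch output unchanged during training.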

learning_rate = 0.001
n_epochs = 100
paddle.seed(42)
np.random.seed(42)
work_path = 'work/model'
# Shift-ours
model = ShiftViT(n_div=12, embed_dim=96, depths=(3, 4, 9, 3), num_classes=10)
criterion = LabelSmoothingCrossEntropy()
scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=learning_rate, T_max=50000 // batch_size * n_epochs, verbose=False)
optimizer = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=scheduler, weight_decay=1e-5)
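The scheduler above is stepped once per batch, and `T_max` is set to the total number of training steps. As a hedged sketch of the resulting learning-rate curve, the snippet below evaluates the standard closed-form cosine annealing expression with a minimum learning rate of 0. Treating Paddle's `CosineAnnealingDecay` as equivalent to this closed form is an assumption here, and `cosine_lr` is a hypothetical helper.

```python
import math

steps_per_epoch = 50000 // 128  # 390 batches, matching drop_last=True
t_max = steps_per_epoch * 100   # same expression as T_max in the scheduler


def cosine_lr(step, base_lr=0.001, t_max=t_max, eta_min=0.0):
    """Closed-form cosine annealing from base_lr down to eta_min over t_max steps."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * step / t_max))


for epoch in (0, 25, 50, 75, 100):
    print(epoch, cosine_lr(epoch * steps_per_epoch))
```

Because the training loader drops the last incomplete batch, each epoch contributes exactly 390 scheduler steps, so the learning rate reaches its minimum at the end of the final epoch.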
gate = 0.0
threshold = 0.0
best_acc = 0.0
val_acc = 0.0
loss_record = {'train': {'loss': [], 'iter': []}, 'val': {'loss': [], 'iter': []}}   # for recording loss
acc_record = {'train': {'acc': [], 'iter': []}, 'val': {'acc': [], 'iter': []}}      # for recording accuracy
loss_iter = 0
acc_iter = 0
for epoch in range(n_epochs):
    # ---------- Training ----------
    model.train()
    train_num = 0.0
    train_loss = 0.0
    val_num = 0.0
    val_loss = 0.0
    accuracy_manager = paddle.metric.Accuracy()
    val_accuracy_manager = paddle.metric.Accuracy()
    print("#===epoch: {}, lr={:.10f}===#".format(epoch, optimizer.get_lr()))
    for batch_id, data in enumerate(train_loader):
        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)
        logits = model(x_data)
        loss = criterion(logits, y_data)
        acc = accuracy_manager.compute(logits, labels)
        accuracy_manager.update(acc)
        if batch_id % 10 == 0:
            loss_record['train']['loss'].append(loss.numpy())
            loss_record['train']['iter'].append(loss_iter)
            loss_iter += 1
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.clear_grad()
        
        train_loss += loss
        train_num += len(y_data)
    total_train_loss = (train_loss / train_num) * batch_size
    train_acc = accuracy_manager.accumulate()
    acc_record['train']['acc'].append(train_acc)
    acc_record['train']['iter'].append(acc_iter)
    acc_iter += 1
    # Print the information.
    print("#===epoch: {}, train loss is: {}, train acc is: {:2.2f}%===#".format(epoch, total_train_loss.numpy(), train_acc*100))
    # ---------- Validation ----------
    model.eval()
    for batch_id, data in enumerate(val_loader):
        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)
        with paddle.no_grad():
          logits = model(x_data)
        loss = criterion(logits, y_data)
        acc = val_accuracy_manager.compute(logits, labels)
        val_accuracy_manager.update(acc)
        val_loss += loss
        val_num += len(y_data)
    total_val_loss = (val_loss / val_num) * batch_size
    loss_record['val']['loss'].append(total_val_loss.numpy())
    loss_record['val']['iter'].append(loss_iter)
    val_acc = val_accuracy_manager.accumulate()
    acc_record['val']['acc'].append(val_acc)
    acc_record['val']['iter'].append(acc_iter)    
    print("#===epoch: {}, val loss is: {}, val acc is: {:2.2f}%===#".format(epoch, total_val_loss.numpy(), val_acc*100))
    # ===================save====================
    if val_acc > best_acc:
        best_acc = val_acc
        paddle.save(model.state_dict(), os.path.join(work_path, 'best_model.pdparams'))
        paddle.save(optimizer.state_dict(), os.path.join(work_path, 'best_optimizer.pdopt'))

print(best_acc)
paddle.save(model.state_dict(), os.path.join(work_path, 'final_model.pdparams'))
paddle.save(optimizer.state_dict(), os.path.join(work_path, 'final_optimizer.pdopt'))
def plot_learning_curve(record, title='loss', ylabel='CE Loss'):
    ''' Plot learning curve of your CNN '''
    maxtrain = max(map(float, record['train'][title]))
    maxval = max(map(float, record['val'][title]))
    ymax = max(maxtrain, maxval) * 1.1
    mintrain = min(map(float, record['train'][title]))
    minval = min(map(float, record['val'][title]))
    ymin = min(mintrain, minval) * 0.9
    total_steps = len(record['train'][title])
    x_1 = list(map(int, record['train']['iter']))
    x_2 = list(map(int, record['val']['iter']))
    figure(figsize=(10, 6))
    plt.plot(x_1, record['train'][title], c='tab:red', label='train')
    plt.plot(x_2, record['val'][title], c='tab:cyan', label='val')
    plt.ylim(ymin, ymax)
    plt.xlabel('Training steps')
    plt.ylabel(ylabel)
    plt.title('Learning curve of {}'.format(title))
    plt.legend()
    plt.show()


plot_learning_curve(loss_record, title='loss', ylabel='CE Loss')
<Figure size 1000x600 with 1 Axes>
plot_learning_curve(acc_record, title='acc', ylabel='Accuracy')
<Figure size 1000x600 with 1 Axes>
import time
work_path = 'work/model'
model = ShiftViT(n_div=12, embed_dim=96, depths=(3, 4, 9, 3), num_classes=10)
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
aa = time.time()
for batch_id, data in enumerate(val_loader):
    x_data, y_data = data
    labels = paddle.unsqueeze(y_data, axis=1)
    with paddle.no_grad():
        logits = model(x_data)
bb = time.time()
# Images per second over the whole validation set
print("Throughput:{}".format(int(len(val_dataset)//(bb - aa))))
Throughput:794
def get_cifar10_labels(labels):
    """Return the text labels of the CIFAR10 dataset."""
    text_labels = [
        'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
        'horse', 'ship', 'truck']
    return [text_labels[int(i)] for i in labels]


def show_images(imgs, num_rows, num_cols, pred=None, gt=None, scale=1.5):
    """Plot a list of images."""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if paddle.is_tensor(img):
            ax.imshow(img.numpy())
        else:
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if pred or gt:
            ax.set_title("pt: " + pred[i] + "\ngt: " + gt[i])
    return axes


work_path = 'work/model'
X, y = next(iter(DataLoader(val_dataset, batch_size=18)))
model = ShiftViT(n_div=12, embed_dim=96, depths=(3, 4, 9, 3), num_classes=10)
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
logits = model(X)
y_pred = paddle.argmax(logits, -1)
X = paddle.transpose(X, [0, 2, 3, 1])
axes = show_images(X.reshape((18, 224, 224, 3)), 1, 18, pred=get_cifar10_labels(y_pred), gt=get_cifar10_labels(y))
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<Figure size 2700x150 with 18 Axes>
| Model | Parameter | Val Acc | 
|---|---|---|
| ShiftViT | 14684554 | 0.93453 | 
| Swin | 14541500 | 0.86659 | 
Note: the Swin code comes from 浅析 Swin Transformer ("A Brief Analysis of Swin Transformer"), and its experimental results are in main-Copy2.ipynb.
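The ShiftViT parameter count in the table can be checked by hand from the layer definitions above. The sketch below adds up the weights and biases of every layer for the (3, 4, 9, 3) configuration with `embed_dim=96`, the default `mlp_ratio=2`, and 10 classes; the function name `count_shiftvit_params` is hypothetical, and the Shift layers contribute nothing because they have no parameters.

```python
def count_shiftvit_params(embed_dim=96, depths=(3, 4, 9, 3), mlp_ratio=2.0,
                          in_chans=3, patch_size=4, num_classes=10):
    """Count trainable parameters of the ShiftViT defined in this article."""
    total = in_chans * embed_dim * patch_size ** 2 + embed_dim  # patch-embedding conv
    total += 2 * embed_dim                                      # patch LayerNorm
    num_stages = len(depths)
    for i, depth in enumerate(depths):
        dim = embed_dim * 2 ** i
        hidden = int(dim * mlp_ratio)
        block = 2 * dim                    # LayerNorm weight + bias
        block += dim * hidden + hidden     # fc1 (1x1 conv)
        block += hidden * dim + dim        # fc2 (1x1 conv)
        total += depth * block
        if i < num_stages - 1:
            total += 2 * dim               # GroupNorm weight + bias
            total += dim * (2 * dim) * 4   # 2x2 stride-2 conv, no bias
    final_dim = embed_dim * 2 ** (num_stages - 1)
    total += 2 * final_dim                           # final LayerNorm
    total += final_dim * num_classes + num_classes   # linear head
    return total


print(count_shiftvit_params())  # 14684554, the ShiftViT row of the table
```

Almost all of the parameters sit in the 1x1-convolution MLPs and the downsampling convolutions, which is consistent with the article's point that the spatial mixing itself costs nothing.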
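This article uses a simple Shift operation to show that attention is not indispensable in a Vision Transformer. With a comparable parameter count, ShiftViT reaches a validation accuracy 0.06794 higher than Swin (0.93453 vs. 0.86659). On small datasets such as CIFAR10, ShiftViT has a clear advantage over Swin, while on large datasets the two perform about the same.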
