Exploring whether attention is necessary for vision transformers

Overview

Do You Even Need Attention?
A Stack of Feed-Forward Layers Does Surprisingly Well on ImageNet

Paper/Report

TL;DR

We replace the attention layer in a vision transformer with a feed-forward layer and find that it still works quite well on ImageNet.

Abstract

The strong performance of vision transformers on image classification and other vision tasks is often attributed to the design of their multi-head attention layers. However, the extent to which attention is responsible for this strong performance remains unclear. In this short report, we ask: is the attention layer even necessary? Specifically, we replace the attention layer in a vision transformer with a feed-forward layer applied over the patch dimension. The resulting architecture is simply a series of feed-forward layers applied over the patch and feature dimensions in an alternating fashion. In experiments on ImageNet, this architecture performs surprisingly well: a ViT/DeiT-base-sized model obtains 74.9% top-1 accuracy, compared to 77.9% and 79.9% for ViT and DeiT respectively. These results indicate that aspects of vision transformers other than attention, such as the patch embedding, may be more responsible for their strong performance than previously thought. We hope these results prompt the community to spend more time trying to understand why our current models are as effective as they are.

Note

This is concurrent research with MLP-Mixer from Google Research. The ideas are exacty the same, with the one difference being that they use (a lot) more compute.

Pretrained models and logs

Here is a Weights and Biases report with the expected training trajectory: W&B

name [email protected] #params url
FF-tiny 61.4 7.7M model
FF-base 74.9 62M model
FF-large 71.4 206M -

Note: I haven't uploaded the FF-Large model because (1) it's over GitHub's file storage limit, and (2) I don't see why anyone would want it, given that it performs worse than the base model. That being said, if you want it, reach out to me and I'll send it to you.

How to train

The model definition in vision_transformer_linear.py is designed to be run with the repo from DeiT, which is itself based on the wonderful timm package.

Steps:

  • Clone the DeiT repo and move the file into it
git clone https://github.com/facebookresearch/deit
mv vision_transformer_linear.py deit
cd deit
  • Add a line to import vision_transformer_linear in main.py. For example, add the following after the import statements (around line 27):
+ import vision_transformer_linear
  • Train:
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch \
--nproc_per_node=8 \
--master_port 10490 \
--use_env main.py \
--model linear_tiny \
--batch-size 128 \
--drop 0.1 \
--output_dir outputs/linear-tiny \
--data-path your/path/to/imagenet

Citation

If you build upon this idea, feel free to drop a citation (and also cite MLP-Mixer).

@article{melaskyriazi2021doyoueven,
  title={Do You Even Need Attention? A Stack of Feed-Forward Layers Does Surprisingly Well on ImageNet},
  author={Luke Melas-Kyriazi},
  journal=arxiv,
  year=2021
}
You might also like...
The first dataset of composite images with rationality score indicating whether the object placement in a composite image is reasonable.
The first dataset of composite images with rationality score indicating whether the object placement in a composite image is reasonable.

Object-Placement-Assessment-Dataset-OPA Object-Placement-Assessment (OPA) is to verify whether a composite image is plausible in terms of the object p

Face Mask Detection is a project to determine whether someone is wearing mask or not, using deep neural network.
Face Mask Detection is a project to determine whether someone is wearing mask or not, using deep neural network.

face-mask-detection Face Mask Detection is a project to determine whether someone is wearing mask or not, using deep neural network. It contains 3 scr

This program was designed to detect whether someone is wearing a facemask through a live video stream.

This program was designed to detect whether someone is wearing a facemask through a live video stream. A custom lightweight CNN trained with TensorFlow on a public dataset provided by Kaggle is used to detect whether each face detected by the cv2 face detection dnn is wearing a mask

A system used to detect whether a person is wearing a medical mask or not.

Mask_Detection_System A system used to detect whether a person is wearing a medical mask or not. To open the program, please follow these steps: Make

Diabet Feature Engineering - Predict whether people have diabetes when their characteristics are specified

Diabet Feature Engineering - Predict whether people have diabetes when their characteristics are specified

Multivariate Time Series Forecasting with efficient Transformers. Code for the paper
Multivariate Time Series Forecasting with efficient Transformers. Code for the paper "Long-Range Transformers for Dynamic Spatiotemporal Forecasting."

Spacetimeformer Multivariate Forecasting This repository contains the code for the paper, "Long-Range Transformers for Dynamic Spatiotemporal Forecast

Implementation of SE3-Transformers for Equivariant Self-Attention, in Pytorch.

SE3 Transformer - Pytorch Implementation of SE3-Transformers for Equivariant Self-Attention, in Pytorch. May be needed for replicating Alphafold2 resu

Official PyTorch implementation for Generic Attention-model Explainability for Interpreting Bi-Modal and Encoder-Decoder Transformers, a novel method to visualize any Transformer-based network. Including examples for DETR, VQA.
Official PyTorch implementation for Generic Attention-model Explainability for Interpreting Bi-Modal and Encoder-Decoder Transformers, a novel method to visualize any Transformer-based network. Including examples for DETR, VQA.

PyTorch Implementation of Generic Attention-model Explainability for Interpreting Bi-Modal and Encoder-Decoder Transformers 1 Using Colab Please notic

Comments
  • On details about the experiments

    On details about the experiments

    In the experiment section of the report, you mention:

    Such a comparison is not exactly fair because the feed-forward model uses stronger training augmentations.

    I wonder what the augmentations are, for details about the augmentation seems missing from the paper.

    opened by boredtylin 2
  • Positional Encoding Ablation

    Positional Encoding Ablation

    Hi Luke, thank you for sharing this amazing work.

    In your arxiv document, I cannot find any mention of positional encoding, but I see that you use them in your code. Did you conduct any ablation study on the PE? i.e., how much does it affect the performance, with and without?

    Thank you in advance.

    opened by jjparkcv 0
  • Interaction between patches through a transpose may have a stronger role to play ?

    Interaction between patches through a transpose may have a stronger role to play ?

    Hi, I was going through your exp report. You have made a point that since you are able to get a good performance without using attention layer so good performance of ViT may be more to do with it's embedding layer than attention .

    But I believe, It's also may be to do with how you have established an interaction between patches through a transpose very similar to what was done in MLP-Mixer .

    Would love to know your thoughts on this ?

    opened by rakshith291 0
Owner
Luke Melas-Kyriazi
I'm student at Harvard University studying mathematics and computer science, always open to collaborate on interesting projects!
Luke Melas-Kyriazi
PyTorch implementation of the method described in the paper VoiceLoop: Voice Fitting and Synthesis via a Phonological Loop.

VoiceLoop PyTorch implementation of the method described in the paper VoiceLoop: Voice Fitting and Synthesis via a Phonological Loop. VoiceLoop is a n

Meta Archive 873 Dec 15, 2022
OpenPCDet Toolbox for LiDAR-based 3D Object Detection.

OpenPCDet OpenPCDet is a clear, simple, self-contained open source project for LiDAR-based 3D object detection. It is also the official code release o

OpenMMLab 3.2k Dec 31, 2022
Jax/Flax implementation of Variational-DiffWave.

jax-variational-diffwave Jax/Flax implementation of Variational-DiffWave. (Zhifeng Kong et al., 2020, Diederik P. Kingma et al., 2021.) DiffWave with

YoungJoong Kim 37 Dec 16, 2022
A PyTorch Image-Classification With AlexNet And ResNet50.

PyTorch 图像分类 依赖库的下载与安装 在终端中执行 pip install -r -requirements.txt 完成项目依赖库的安装 使用方式 数据集的准备 STL10 数据集 下载:STL-10 Dataset 存储位置:将下载后的数据集中 train_X.bin,train_y.b

FYH 4 Feb 22, 2022
PyTorch code for the paper "Complementarity is the King: Multi-modal and Multi-grained Hierarchical Semantic Enhancement Network for Cross-modal Retrieval".

Complementarity is the King: Multi-modal and Multi-grained Hierarchical Semantic Enhancement Network for Cross-modal Retrieval (M2HSE) PyTorch code fo

Xinlei-Pei 6 Dec 23, 2022
Learning Generative Models of Textured 3D Meshes from Real-World Images, ICCV 2021

Learning Generative Models of Textured 3D Meshes from Real-World Images This is the reference implementation of "Learning Generative Models of Texture

Dario Pavllo 115 Jan 07, 2023
基于Paddlepaddle复现yolov5,支持PaddleDetection接口

PaddleDetection yolov5 https://github.com/Sharpiless/PaddleDetection-Yolov5 简介 PaddleDetection飞桨目标检测开发套件,旨在帮助开发者更快更好地完成检测模型的组建、训练、优化及部署等全开发流程。 PaddleD

36 Jan 07, 2023
Skyformer: Remodel Self-Attention with Gaussian Kernel and Nystr\"om Method (NeurIPS 2021)

Skyformer This repository is the official implementation of Skyformer: Remodel Self-Attention with Gaussian Kernel and Nystr"om Method (NeurIPS 2021).

Qi Zeng 46 Sep 20, 2022
Tackling data scarcity in Speech Translation using zero-shot multilingual Machine Translation techniques

Tackling data scarcity in Speech Translation using zero-shot multilingual Machine Translation techniques This repository is derived from the NMTGMinor

Tu Anh Dinh 1 Sep 07, 2022
Space Time Recurrent Memory Network - Pytorch

Space Time Recurrent Memory Network - Pytorch (wip) Implementation of Space Time Recurrent Memory Network, recurrent network competitive with attentio

Phil Wang 50 Nov 07, 2021
Code release for ICCV 2021 paper "Anticipative Video Transformer"

Anticipative Video Transformer Ranked first in the Action Anticipation task of the CVPR 2021 EPIC-Kitchens Challenge! (entry: AVT-FB-UT) [project page

Facebook Research 123 Dec 13, 2022
A Deep Reinforcement Learning Framework for Stock Market Trading

DQN-Trading This is a framework based on deep reinforcement learning for stock market trading. This project is the implementation code for the two pap

61 Jan 01, 2023
Some pvbatch (paraview) scripts for postprocessing OpenFOAM data

pvbatchForFoam Some pvbatch (paraview) scripts for postprocessing OpenFOAM data For every script there is a help message available: pvbatch pv_state_s

Morev Ilya 2 Oct 26, 2022
A PyTorch Toolbox for Face Recognition

FaceX-Zoo FaceX-Zoo is a PyTorch toolbox for face recognition. It provides a training module with various supervisory heads and backbones towards stat

JDAI-CV 1.6k Jan 06, 2023
[ICSE2020] MemLock: Memory Usage Guided Fuzzing

MemLock: Memory Usage Guided Fuzzing This repository provides the tool and the evaluation subjects for the paper "MemLock: Memory Usage Guided Fuzzing

Cheng Wen 54 Jan 07, 2023
This is just a funny project that we want to see AutoEncoder (AE) can actually work to enhance the features we want

Funny_muscle_enhancer :) 1.Discription: This is just a funny project that we want to see AutoEncoder (AE) can actually work on the some features. We w

Jing-Yao Chen (Jacob) 8 Oct 01, 2022
Pre-trained models for a Cascaded-FCN in caffe and tensorflow that segments

Cascaded-FCN This repository contains the pre-trained models for a Cascaded-FCN in caffe and tensorflow that segments the liver and its lesions out of

300 Nov 22, 2022
Continuum Learning with GEM: Gradient Episodic Memory

Gradient Episodic Memory for Continual Learning Source code for the paper: @inproceedings{GradientEpisodicMemory, title={Gradient Episodic Memory

Facebook Research 360 Dec 27, 2022
HODEmu, is both an executable and a python library that is based on Ragagnin 2021 in prep.

HODEmu HODEmu, is both an executable and a python library that is based on Ragagnin 2021 in prep. and emulates satellite abundance as a function of co

Antonio Ragagnin 1 Oct 13, 2021
Anti-Adversarially Manipulated Attributions for Weakly and Semi-Supervised Semantic Segmentation (CVPR 2021)

Anti-Adversarially Manipulated Attributions for Weakly and Semi-Supervised Semantic Segmentation Input Image Initial CAM Successive Maps with adversar

Jungbeom Lee 110 Dec 07, 2022