PoolFormer: MetaFormer is Actually What You Need for Vision

PoolFormer: MetaFormer is Actually What You Need for Vision (arXiv)

This is a PyTorch implementation of PoolFormer proposed by our paper "MetaFormer is Actually What You Need for Vision".

Figure 1: MetaFormer and performance of MetaFormer-based models on the ImageNet-1K validation set. We argue that the competence of transformer/MLP-like models primarily stems from the general architecture MetaFormer rather than from the specific token mixers they are equipped with. To demonstrate this, we use an embarrassingly simple non-parametric operator, pooling, to conduct extremely basic token mixing. Surprisingly, the resulting model, PoolFormer, consistently outperforms DeiT and ResMLP as shown in (b), which strongly supports the claim that MetaFormer is actually what we need to achieve competitive performance.

Figure 2: (a) The overall framework of PoolFormer. (b) The architecture of the PoolFormer block. Compared with the transformer block, it replaces attention with an extremely simple non-parametric operator, pooling, to conduct only basic token mixing.
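
To make the pooling token mixer concrete, here is a minimal pure-Python sketch for a single channel of a feature map. It is an illustration, not the repository's PyTorch module: the function name `pool_token_mixer`, the default `pool_size=3`, stride 1, and the choice to leave padded positions out of the average are all assumptions. Subtracting the input follows the `self.pool(x) - x` expression quoted in the comments below.

```python
def pool_token_mixer(x, pool_size=3):
    """Average each position over its neighborhood, then subtract the input.

    x is a 2D list (H x W) holding one channel of a feature map.
    Positions outside the map are skipped rather than counted as zeros
    (an assumption made for this sketch).
    """
    h, w = len(x), len(x[0])
    r = pool_size // 2  # neighborhood radius for an odd pool size
    out = [[0.0] * w for _ in range(h)]
    for i in range(h):
        for j in range(w):
            window = [
                x[a][b]
                for a in range(max(0, i - r), min(h, i + r + 1))
                for b in range(max(0, j - r), min(w, j + r + 1))
            ]
            # Mixed token = local average minus the token itself.
            out[i][j] = sum(window) / len(window) - x[i][j]
    return out


feature_map = [
    [1.0, 2.0, 3.0],
    [4.0, 5.0, 6.0],
    [7.0, 8.0, 9.0],
]
mixed = pool_token_mixer(feature_map)
print(mixed[1][1])  # 0.0: the mean of all nine values is 5.0, minus the center 5.0
print(mixed[0][0])  # 2.0: the mean of 1, 2, 4, 5 is 3.0, minus the corner 1.0
```

The operator has no learnable parameters. A constant feature map is mapped to all zeros, so the mixer only passes on how each token differs from its neighborhood.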

Bibtex

@article{yu2021metaformer,
  title={MetaFormer is Actually What You Need for Vision},
  author={Yu, Weihao and Luo, Mi and Zhou, Pan and Si, Chenyang and Zhou, Yichen and Wang, Xinchao and Feng, Jiashi and Yan, Shuicheng},
  journal={arXiv preprint arXiv:2111.11418},
  year={2021}
}

1. Requirements

For Image Classification (Configs of detection and segmentation will be available soon)

torch>=1.7.0; torchvision>=0.8.0; pyyaml; apex-amp (if you want to use fp16); timm (pip install git+https://github.com/rwightman/p[email protected])

Data preparation: ImageNet with the following folder structure. You can extract ImageNet with this script.

│imagenet/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......
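
Before launching a long run, it can help to check that the dataset matches the layout above. The following sketch counts class folders and `.JPEG` files per split. The helper name `summarize_imagenet` is hypothetical and not part of this repo. The demo builds a toy tree in a temporary directory instead of touching a real dataset.

```python
import tempfile
from pathlib import Path


def summarize_imagenet(root):
    """Return {split: (num_class_folders, num_jpeg_files)} for train/ and val/."""
    summary = {}
    for split in ("train", "val"):
        split_dir = Path(root) / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing split folder: {split_dir}")
        class_dirs = [d for d in split_dir.iterdir() if d.is_dir()]
        n_images = sum(1 for d in class_dirs for _ in d.glob("*.JPEG"))
        summary[split] = (len(class_dirs), n_images)
    return summary


# Demo on a toy tree that mirrors the expected layout.
with tempfile.TemporaryDirectory() as tmp:
    samples = (
        ("train", "n01440764_10026.JPEG"),
        ("val", "ILSVRC2012_val_00000293.JPEG"),
    )
    for split, file_name in samples:
        class_dir = Path(tmp) / split / "n01440764"
        class_dir.mkdir(parents=True)
        (class_dir / file_name).touch()
    demo = summarize_imagenet(tmp)

print(demo)  # {'train': (1, 1), 'val': (1, 1)}
```

On the full ImageNet-1K dataset, both splits should report 1000 class folders.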

Directory structure in this repo:

│poolformer/
├──misc/
├──models/
│  ├── __init__.py
│  ├── poolformer.py
├──LICENSE
├──README.md
├──distributed_train.sh
├──train.py
├──validate.py

2. PoolFormer Models

Model            #Params   Image resolution   Top-1 Acc (%)   Download
poolformer_s12   12M       224                77.2            here
poolformer_s24   21M       224                80.3            here
poolformer_s36   31M       224                81.4            here
poolformer_m36   56M       224                82.1            here
poolformer_m48   73M       224                82.5            here

All the pretrained models can also be downloaded by BaiDu Yun (password: esac).

Updated ResNet Scores in the Paper

[Image: updated ResNet scores]

[1] He et al., "Deep Residual Learning for Image Recognition", CVPR 2016.

[2] Wightman et al., "Resnet strikes back: An improved training procedure in timm", arXiv preprint arXiv:2110.00476. 2021 Oct 1.

Usage

We also provide a Colab notebook which runs the steps to perform inference with PoolFormer.

3. Validation

To evaluate our PoolFormer models, run:

MODEL=poolformer_s12 # poolformer_{s12, s24, s36, m36, m48}
python3 validate.py /path/to/imagenet  --model $MODEL \
  --checkpoint /path/to/checkpoint -b 128

4. Train

We show how to train PoolFormers on 8 GPUs. The relation between learning rate and batch size is lr=bs/1024*1e-3. For convenience, assuming the batch size is 1024, then the learning rate is set as 1e-3 (for batch size of 1024, setting the learning rate as 2e-3 sometimes sees better performance).

MODEL=poolformer_s12 # poolformer_{s12, s24, s36, m36, m48}
DROP_PATH=0.1 # drop path rates [0.1, 0.1, 0.2, 0.3, 0.4] corresponding to models [s12, s24, s36, m36, m48]
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 /path/to/imagenet \
  --model $MODEL -b 128 --lr 1e-3 --drop-path $DROP_PATH --apex-amp
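
The scaling rule above can be sketched as a small helper for picking `--lr` on a different number of GPUs. This assumes `-b` is the per-GPU batch size, which is inferred from the command above (8 GPUs × 128 = 1024). The names `scaled_lr` and `DROP_PATH` are illustrative and do not come from the repo.

```python
def scaled_lr(num_gpus, batch_per_gpu, base_lr=1e-3, base_batch=1024):
    """Linear scaling rule from the text: lr = total_batch / 1024 * 1e-3."""
    total_batch = num_gpus * batch_per_gpu
    return total_batch / base_batch * base_lr


# Drop path rates listed in the training command's comment.
DROP_PATH = {
    "poolformer_s12": 0.1,
    "poolformer_s24": 0.1,
    "poolformer_s36": 0.2,
    "poolformer_m36": 0.3,
    "poolformer_m48": 0.4,
}

lr_8gpu = scaled_lr(num_gpus=8, batch_per_gpu=128)  # total batch 1024
lr_4gpu = scaled_lr(num_gpus=4, batch_per_gpu=128)  # total batch 512
print(lr_8gpu, lr_4gpu, DROP_PATH["poolformer_m48"])
```

With 4 GPUs and the same per-GPU batch size, the rule gives half the learning rate, 5e-4.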

5. Acknowledgment

Our implementation is mainly based on the following codebases. We gratefully thank the authors for their wonderful work.

pytorch-image-models, mmdetection, mmsegmentation.

In addition, Weihao Yu would like to thank the TPU Research Cloud (TRC) program for supporting part of the computational resources.

LICENSE

This repo is under the Apache-2.0 license. For commercial use, please contact the authors.

Comments
  • About Normalization

    Hi, thanks for your excellent work. In your ablation studies (section 4.4), you compared Group Normalization (group number is set as 1 for simplicity), Layer Normalization, and Batch Normalization. The conclusion is that Group Normalization is 0.7% or 0.8% higher than Layer Normalization or Batch Normalization. But when the number of groups is 1, Group Normalization is equivalent to Layer Normalization, right?
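
One possible reading of this question can be illustrated numerically. It is an assumption here that "Layer Normalization" in the ablation means the transformer-style variant that normalizes each token over its channels only, while Group Normalization with one group shares statistics across all tokens and channels of a sample. The sketch below uses plain Python and omits the learnable affine parameters.

```python
def normalize(values, eps=1e-5):
    """Standardize a flat list of numbers to zero mean and unit variance."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / (var + eps) ** 0.5 for v in values]


# tokens[t][c]: two tokens with two channels each (toy values)
tokens = [[0.0, 2.0], [10.0, 14.0]]
num_channels = len(tokens[0])

# Statistics computed per token, over the channel dimension only.
per_token = [normalize(tok) for tok in tokens]

# Statistics shared across all tokens and channels of the sample.
flat = normalize([v for tok in tokens for v in tok])
shared = [flat[i * num_channels:(i + 1) * num_channels] for i in range(len(tokens))]

print(per_token)  # each token is centered on its own mean
print(shared)     # the second token stays larger than the first
```

Per-token statistics erase the difference in magnitude between the two tokens, whereas shared statistics preserve it, so the two schemes are not equivalent under this reading.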

    opened by tinyalpha 10
  • Addition of the Organization on HuggingFace Transformers

    Hello PoolFormer team!

    I have been working on porting the implementation of PoolFormer to HuggingFace Transformers library (you can see my PR here) and I was wondering if I can go ahead and add Sea AI labs as an organization to the HuggingFace models hub.

    This will allow all model checkpoints to be uploaded onto the hub as well as model cards, etc.

    Kind regards, Tanay Mehta

    opened by heytanay 7
  • How to measure MACs?

    Hi, thanks for your nice work :) I also watched the recording of your presentation at this conference.

    I want to apply PoolFormer in my work. Could I ask how you measured the MACs of the architectures introduced in your paper? If it is not too much trouble, could you share your measurement code?

    opened by DoranLyong 5
  • why use use_layer_scale

    Thanks for your great contribution! In the implementation of PoolFormerBlock, there is a layer_scale after token_mixer. What is the impact of this operation?

    opened by rtfgithub 5
  • Invitation of making PR for OpenMMLab / MMSegmentation.

    Hi, first of all, congratulations on the acceptance to CVPR 2022. This work deserves it because it is excellent.

    I am a member of OpenMMLab and mainly work on developing MMSegmentation. I think that if it were officially supported, many more people would use it for benchmarking, which would promote research in computer vision.

    Would you like to make a PR for OpenMMLab? We could discuss together how to refactor your code, and use our own GPUs to train and re-implement it.

    I think it would be pretty cool because it would let more researchers and community members use this excellent work! Here is one of our re-implementations: ConvNeXt.

    We do hope PoolFormer can also be added as a backbone in our codebase so that many researchers can use it directly for downstream tasks.

    Looking forward to your reply!

    Best,

    opened by MengzhangLI 5
  • why the speed slower than pvtv2-b1?

    Recently I trained a transformer-based instance segmentation model and tested it with different backbones. Here are the results and the speed test:

    image

    The batch size shown is the training batch size. Why is PoolFormer the slowest one? Is that normal?

    It is slower than pvtv2-b1 and also lower in precision...

    opened by jinfagang 5
  • Checkpoints of the Ablation study

    Hi, thanks for your amazing work. I am reading Table 6, and I am surprised that the method is so simple yet so effective, especially when pooling is replaced with identity mapping: 74.3 top-1 on ImageNet-1K with only Conv1x1 and Norm layers. I am thrilled... Could you release this checkpoint so that we can verify it? Thanks again.

    image

    opened by chuong98 5
  • Design on positional embedding?

    Hello authors,

    I appreciate your current work a lot; it has inspired the community. I would like to raise a simple, quick question after checking the code and architecture design.

    I observed that networks using pooling, MLP, or identity as the token mixer do not include positional embedding, and that you include this component only when you use MHA. What is the reasoning behind this design, and why do the other models not rely on this embedding?

    Best,

    discussion 
    opened by jizongFox 4
  • Error: About self.pool(x)

    Hello, I am very interested in the PoolFormer you proposed, but an error occurred when using PoolFormerBlock:

    Traceback (most recent call last):
      File "train.py", line 545, in <module>
        train(hyp, opt, device, tb_writer)
      File "train.py", line 89, in train
        model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
      File "E:\Work\yolov5\models\yolo.py", line 106, in __init__
        m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))])  # forward
      File "E:\Work\yolov5\models\yolo.py", line 138, in forward
        return self.forward_once(x, profile)  # single-scale inference, train
      File "E:\Work\yolov5\models\yolo.py", line 157, in forward_once
        x = m(x)  # run  # execute the network component
      File "C:\conda\conda\envs\torch17\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
        result = self.forward(*input, **kwargs)
      File "E:\Work\yolov5_T23\models\common.py", line 194, in forward
        n = self.token_mixer(m)
      File "C:\conda\conda\envs\torch17\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
        result = self.forward(*input, **kwargs)
      File "E:\Work\yolov5_T23\models\Confor_VC.py", line 93, in forward
        x1 = self.pool(x) - x  # x1 = self.pool(x) - x
      File "C:\conda\conda\envs\torch17\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
        result = self.forward(*input, **kwargs)
      File "C:\conda\conda\envs\torch17\lib\site-packages\torch\nn\modules\pooling.py", line 594, in forward
        return F.avg_pool2d(input, self.kernel_size, self.stride,
    TypeError: avg_pool2d(): argument 'kernel_size' (position 2) must be tuple of ints, not bool

    I want to place PoolFormer after a ConvBlock, and the problem above occurred. Thank you!

    opened by QY1994-0919 4
  • About MLN(Modified Layer Normalization)

    This paper provides new perspectives on the Transformer block, but I have a question about one of the details. As far as I know, the LayerNorm officially provided by PyTorch implements the same function as MLN, which computes the mean and variance along the token and channel dimensions. So where is the improvement?

    image

    The official example:

    # Image Example
    N, C, H, W = 20, 5, 10, 10
    input = torch.randn(N, C, H, W)
    # Normalize over the last three dimensions (i.e. the channel and spatial dimensions)
    # as shown in the image below
    layer_norm = nn.LayerNorm([C, H, W])
    output = layer_norm(input)

    opened by youngtboy 3
  • How to achieve the grad-CAM visualization?

    Thanks for your awesome work and for sharing it all.

    I found that the pictures in the supplementary material are beautiful, and I would like to reproduce them.

    Could you share the code for this, or tell me how to produce the Grad-CAM activation maps?

    opened by DoranLyong 3
  • CVE-2007-4559 Patch

    Patching CVE-2007-4559

    Hi, we are security researchers from the Advanced Research Center at Trellix. We have begun a campaign to patch a widespread bug named CVE-2007-4559. CVE-2007-4559 is a 15-year-old bug in the Python tarfile package. By using extract() or extractall() on a tarfile object without sanitizing input, a maliciously crafted .tar file could perform a directory path traversal attack. We found at least one unsanitized extractall() in your codebase and are providing a patch for you via pull request. The patch essentially checks whether all tarfile members will be extracted safely and throws an exception otherwise. We encourage you to use this patch or your own solution to secure against CVE-2007-4559. Further technical information about the vulnerability can be found in this blog.

    If you have further questions, you may contact us through this project's lead researcher, Kasimir Schulz.
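
The kind of check described above can be sketched with the standard library as follows. This is not the patch from the pull request: the helper names `is_within_directory` and `safe_extract` are illustrative, and the sketch only guards against member names that escape the target directory (it does not inspect symlinks). The demo builds two tiny in-memory archives to show one accepted and one rejected.

```python
import io
import os
import tarfile
import tempfile


def is_within_directory(directory, target):
    """True if target resolves to a path inside directory."""
    abs_directory = os.path.abspath(directory)
    abs_target = os.path.abspath(target)
    return os.path.commonpath([abs_directory, abs_target]) == abs_directory


def safe_extract(tar, path="."):
    """Extract only if every member stays inside path, else raise ValueError."""
    for member in tar.getmembers():
        member_path = os.path.join(path, member.name)
        if not is_within_directory(path, member_path):
            raise ValueError(f"blocked path traversal in tar member: {member.name}")
    tar.extractall(path)


def make_tar(member_name):
    """Build an in-memory tar archive holding one small file (demo helper)."""
    buf = io.BytesIO()
    data = b"hello"
    with tarfile.open(fileobj=buf, mode="w") as tar:
        info = tarfile.TarInfo(name=member_name)
        info.size = len(data)
        tar.addfile(info, io.BytesIO(data))
    buf.seek(0)
    return buf


with tempfile.TemporaryDirectory() as tmp:
    # A well-behaved archive is extracted normally.
    with tarfile.open(fileobj=make_tar("ok.txt")) as tar:
        safe_extract(tar, tmp)
    extracted_ok = os.path.exists(os.path.join(tmp, "ok.txt"))

    # An archive whose member points outside the target is rejected.
    try:
        with tarfile.open(fileobj=make_tar("../evil.txt")) as tar:
            safe_extract(tar, tmp)
        blocked = False
    except ValueError:
        blocked = True

print(extracted_ok, blocked)  # True True
```

Recent Python versions also accept a `filter` argument to `extractall()`, which may be preferable where it is available.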

    opened by TrellixVulnTeam 0
  • On the use of Apex AMP and hybrid stages

    Is there a specific reason why you used Apex AMP instead of the native AMP provided by PyTorch? Have you tried native AMP?

    I tried to train poolformer_s12 and poolformer_s24 with solo-learn; with native fp16 the loss goes to nan after a few epochs, while with fp32 it works fine. Did you experience similar behavior?

    On a side note, can you provide the implementation and the hyperparameters for the hybrid stage [Pool, Pool, Attention, Attention]? It seems very interesting!

    discussion 
    opened by DonkeyShot21 6
  • Can I say PoolFormer is just a non-trainable MLP-like module?

    Hi! Thanks for sharing the great work! I have some questions about PoolFormer. If I explain PoolFormer like the following attachments, can I say PoolFormer is just a non-trainable MLP-like model?

    image image

    discussion 
    opened by 072jiajia 8
  • About subtract in pooling

    Hi, thank you for publishing such a nice paper. I just have one question. I do not understand the subtraction of the input in Eq. 4. Is it necessary? What happens if we just do the average pooling without subtracting the input?

    discussion 
    opened by Dong-Huo 16
Owner
Sea AI Lab