Implementation of ResMLP, an all MLP solution to image classification, in Pytorch

Overview

ResMLP - Pytorch

Implementation of ResMLP, an all MLP solution to image classification out of Facebook AI, in Pytorch

Install

$ pip install res-mlp-pytorch

Usage

import torch
from res_mlp_pytorch import ResMLP

model = ResMLP(
    image_size = 256,
    patch_size = 16,
    dim = 512,
    depth = 12,
    num_classes = 1000
)

img = torch.randn(1, 3, 256, 256)
pred = model(img) # (1, 1000)
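
The image size must be divisible by the patch size, and the model splits each image into (image_size / patch_size)² non-overlapping patches. A quick sanity check of that arithmetic (plain Python, for illustration only):

image_size, patch_size = 256, 16
assert image_size % patch_size == 0  # ResMLP requires an exact split into patches
num_patches = (image_size // patch_size) ** 2
print(num_patches)  # 256 patches, each embedded into a vector of size dim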

Citations

@misc{touvron2021resmlp,
    title   = {ResMLP: Feedforward networks for image classification with data-efficient training}, 
    author  = {Hugo Touvron and Piotr Bojanowski and Mathilde Caron and Matthieu Cord and Alaaeldin El-Nouby and Edouard Grave and Armand Joulin and Gabriel Synnaeve and Jakob Verbeek and Hervé Jégou},
    year    = {2021},
    eprint  = {2105.03404},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
You might also like...
Implementation of Segformer, Attention + MLP neural network for segmentation, in Pytorch

Segformer - Pytorch Implementation of Segformer, Attention + MLP neural network for segmentation, in Pytorch. Install $ pip install segformer-pytorch

🍀 Pytorch implementation of various Attention Mechanisms, MLP, Re-parameter, Convolution, which is helpful to further understand papers.⭐⭐⭐

Pytorch implementation of MLP-Mixer with loading pre-trained models.

MLP-Mixer-Pytorch PyTorch implementation of MLP-Mixer: An all-MLP Architecture for Vision with the function of loading official ImageNet pre-trained p

Image Classification - A research on image classification and auto insurance claim prediction, with systematic experiments on modeling techniques and approaches

MLP-Like Vision Permutator for Visual Recognition (PyTorch)

Vision Permutator: A Permutable MLP-Like Architecture for Visual Recognition (arxiv) This is a Pytorch implementation of our paper. We present Vision

Xview3 solution - XView3 challenge, 2nd place solution

Xview3, 2nd place solution https://iuu.xview.us/ test split aggregate score publ

Unofficial Implementation of MLP-Mixer in TensorFlow

mlp-mixer-tf Unofficial Implementation of MLP-Mixer [abs, pdf] in TensorFlow. Note: This project may have some bugs in it. I'm still learning how to i

Implementation of "A MLP-like Architecture for Dense Prediction"

A MLP-like Architecture for Dense Prediction (arXiv) Updates (22/07/2021) Initial release. Model Zoo We provide CycleMLP models pretrained on ImageNet

MLP-Numpy - A simple modular implementation of Multi Layer Perceptron in pure Numpy.

MLP-Numpy A simple modular implementation of Multi Layer Perceptron in pure Numpy. I used the Iris dataset from scikit-learn library for the experimen

Comments
  • torch dataset example

    I wrote this example with a data loader:

    import os
    import natsort
    from PIL import Image
    import torch
    import torchvision.transforms as T
    from res_mlp_pytorch.res_mlp_pytorch import ResMLP
    
    class LPCustomDataSet(torch.utils.data.Dataset):
        '''
            Naive Torch Image Dataset Loader
            with support for Image loading errors
            and Image resizing
        '''
        def __init__(self, main_dir, transform):
            self.main_dir = main_dir
            self.transform = transform
            all_imgs = os.listdir(main_dir)
            self.total_imgs = natsort.natsorted(all_imgs)
    
        def __len__(self):
            return len(self.total_imgs)
    
        def __getitem__(self, idx):
            img_loc = os.path.join(self.main_dir, self.total_imgs[idx])
            try:
                image = Image.open(img_loc).convert("RGB")
                tensor_image = self.transform(image)
                return tensor_image
            except Exception:
                # return None for unreadable images; collate_fn filters them out
                return None
    
        @classmethod
        def collate_fn(cls, batch):
            '''
                Collate filtering not None images
            '''
            batch = list(filter(lambda x: x is not None, batch))
            return torch.utils.data.dataloader.default_collate(batch)
    
        @classmethod
        def transform(cls, img):
            '''
                Naive image resizer
            '''
            transform = T.Compose([
                T.Resize(256),
                T.CenterCrop(224),
                T.ToTensor(),
                T.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                )
            ])
            return transform(img)
    

    to feed ResMLP:

    model = ResMLP(
        image_size = 256,
        patch_size = 16,
        dim = 512,
        depth = 12,
        num_classes = 1000
    )
    batch_size = 2
    my_dataset = LPCustomDataSet(os.path.join(os.path.dirname(
        os.path.abspath(__file__)), 'data'), transform=LPCustomDataSet.transform)
    train_loader = torch.utils.data.DataLoader(my_dataset, batch_size=batch_size, shuffle=False,
                                   num_workers=4, drop_last=True, collate_fn=LPCustomDataSet.collate_fn)
    for idx, img in enumerate(train_loader):
        pred = model(img) # (batch_size, 1000)
        print(idx, img.shape, pred.shape)
    

    But I get this error:

    RuntimeError: Given groups=1, weight of size [256, 256, 1], expected input[1, 196, 512] to have 256 channels, but got 196 channels instead
    

    I am not sure whether LPCustomDataSet.transform produces the correct size for the input image.
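
    Reading the shapes in the error, a likely cause is that the transform center-crops to 224, which yields (224 / 16)² = 196 patches, while a model built with image_size = 256 expects (256 / 16)² = 256 patches. A minimal sketch of a transform matching image_size = 256, reusing the T alias imported above (alternatively, the model could be built with image_size = 224):

    transform_256 = T.Compose([
        T.Resize(256),
        T.CenterCrop(256),   # crop to 256 so the patch count matches a model built with image_size = 256
        T.ToTensor(),
        T.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])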

    opened by loretoparisi 3
  • add dropout and CIFAR100 example notebook

    • According to the ResMLP paper, a dropout layer appears to be used in the machine translation setting with ResMLP:
    We use Adagrad with learning rate 0.2, 32k steps of linear warmup, label smoothing 0.1, dropout rate 0.15 for En-De and 0.1 for En-Fr.
    
    • Since the MLP literature often notes that MLPs are susceptible to overfitting, which is one of the reasons the weight decay is so high, adding dropout is a reasonable choice of regularization (see the sketch after this list).

    Open in Colab | 🔗 Wandb Log

    • Above is my simple experiment on the CIFAR100 dataset, with three different dropout rates: [0.0, 0.25, 0.5].
    • Higher dropout yielded better test metrics (loss, acc1 and acc5).
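
    For reference, a minimal sketch of where a dropout layer could sit inside an MLP feedforward block (plain PyTorch; this block is illustrative and is not the res_mlp_pytorch implementation):

    import torch.nn as nn

    def feedforward_with_dropout(dim, expansion_factor = 4, dropout = 0.25):
        # illustrative channel-mixing block: dropout after the activation and after the
        # output projection, mirroring the usual placement in Transformer-style blocks
        return nn.Sequential(
            nn.Linear(dim, dim * expansion_factor),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * expansion_factor, dim),
            nn.Dropout(dropout)
        )
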
    opened by snoop2head 0
  • What learning rate/scheduler/optimizer are suitable for training mlp-mixer?

    Thanks for your code!

    I find it is very important to set a suitable lr/scheduler/optimizer for training res-mlp models. In my experiments on a small dataset, the classification performance is very poor when I train models with lr=1e-3 or 1e-4, weight-decay=5e-4, scheduler=WarmupCosineLrScheduler, optim='sgd'. The results improve remarkably with lr=5e-3, weight-decay=0.2, scheduler=WarmupCosineLrScheduler, optim='lamb'.

    However, the results are still much lower than those of CNN models with a comparable number of parameters trained from scratch. Could you provide any suggestions for training res-mlp?
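
    For context, a minimal sketch of that kind of setup in plain PyTorch, with AdamW standing in for LAMB (LAMB is not in core PyTorch and needs an external package) and the hyperparameters taken from the comment above rather than from any verified recipe:

    import torch
    from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR

    # `model` is the ResMLP instance from the Usage section above; the step counts are hypothetical
    optimizer = torch.optim.AdamW(model.parameters(), lr = 5e-3, weight_decay = 0.2)
    warmup_steps, total_steps = 1000, 10000
    scheduler = SequentialLR(
        optimizer,
        schedulers = [
            LinearLR(optimizer, start_factor = 0.01, total_iters = warmup_steps),  # linear warmup
            CosineAnnealingLR(optimizer, T_max = total_steps - warmup_steps)       # cosine decay
        ],
        milestones = [warmup_steps]
    )
    # call scheduler.step() once per training step, after optimizer.step()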

    opened by QiushiYang 0
Releases (0.0.6)
Owner
Phil Wang
Working with Attention.