Implementation of gMLP, an all-MLP replacement for Transformers, in Pytorch

Overview

gMLP - Pytorch

Implementation of gMLP, an all-MLP replacement for Transformers, in Pytorch

Install

$ pip install g-mlp-pytorch

Usage

For masked language modelling

import torch
from g_mlp_pytorch import gMLP

model = gMLP(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    seq_len = 256
)

x = torch.randint(0, 20000, (1, 256))
emb = model(x) # (1, 256, 512)

For image classification

import torch
from g_mlp_pytorch import gMLPVision

model = gMLPVision(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 512,
    depth = 6
)

img = torch.randn(1, 3, 256, 256)
pred = model(img) # (1, 1000)

You can also add a tiny amount of attention (one-headed) to boost performance, as mentioned in the paper as aMLP, with the addition of one extra keyword attn_dim. This applies to both gMLPVision and gMLP

import torch
from g_mlp_pytorch import gMLPVision

model = gMLPVision(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 512,
    depth = 6,
    attn_dim = 64
)

img = torch.randn(1, 3, 256, 256)
pred = model(img) # (1, 1000)

Citations

@misc{liu2021pay,
    title   = {Pay Attention to MLPs}, 
    author  = {Hanxiao Liu and Zihang Dai and David R. So and Quoc V. Le},
    year    = {2021},
    eprint  = {2105.08050},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
Comments
  • Custom image sizes?

    Custom image sizes?

    Hi, Thanks for your great (and very fast) contribution! I was wondering if you could help me figure out how to apply this to a different image size? It's not really an image, but rather a 2D dimensional tensor of 4096X100.

    I saw that I can change the number of channels, so I could just set channels to be 1. But I see that firstly - your implementation is for squared images, and secondly, it requires that image size should be devisable by patch size.

    Since you've written this implementation perhaps you could help me to adapt it for my needs? (and maybe other users for their cases).

    Maybe I could pad the length to be 128 so both would be devisable by 16 for example? but then where do I set different h, w ?

    Thanks.

    opened by danarte 3
  • Parameter count doesnt line up with paper

    Parameter count doesnt line up with paper

    Just a note (and correct me if I misunderstood the paper) -

    The parameter count for the Tiny gMLP doesnt line up with the param count from the paper for 30 layers and 128 dim and 6 ff_mult. Thats probably due to the doubling of parameters here - https://github.com/lucidrains/g-mlp-pytorch/blob/main/g_mlp_pytorch/g_mlp_pytorch.py#L111

    Halving this back to dim_ff + all 3 lines here need to halve their respective dims - https://github.com/lucidrains/g-mlp-pytorch/blob/main/g_mlp_pytorch/g_mlp_pytorch.py#L64-L66

    Then param count is roughly 5.5 M params.

    opened by titu1994 2
  • Add Support for Stochastic Depth

    Add Support for Stochastic Depth

    This PR adds support for stochastic depth, which is used in the paper for the vision experiments. I went ahead an added it to gMLP as well for completeness.

    I tried my best to match your style. Let me know if there are any problems, or if you want me to refactor anything.

    opened by mlw214 2
  • Don't you think this is more legible?

    Don't you think this is more legible?

    ` class SpatialGatingUnit(nn.Module): def init(self, dim, dim_seq, causal = False, act = nn.Identity(), init_eps = 1e-3): super().init() dim_out = dim // 2 self.causal = causal

        self.norm = nn.LayerNorm(dim_out)
        #self.proj = nn.Conv1d(dim_seq, dim_seq, 1)
    
        self.dim_seq = dim_seq
        self.w_ = nn.Parameter(torch.zeros(dim_seq, dim_seq), requires_grad=True)   ####
        self.b_ = nn.Parameter(torch.ones(dim_seq), requires_grad=True)  ####
    
        self.act = act
    
        init_eps /= dim_seq
        #nn.init.uniform_(self.proj.weight, -init_eps, init_eps)
        #nn.init.constant_(self.proj.bias, 1.)
    
    def forward(self, x, gate_res = None): # x -> bsz, len, hidden*6
        device, n = x.device, x.shape[1]
    
        res, gate = x.chunk(2, dim = -1)
        gate = self.norm(gate)
    
        weight, bias = self.w_, self.b_ # weight -> len, len, 1     bias -> len
    
        if self.causal:
            weight.unsqueeze(-1) # TODO
            weight, bias = weight[:n, :n], bias[:n]
            mask = torch.ones(weight.shape[:2], device = device).triu_(1).bool()
            weight = weight.masked_fill(mask[..., None], 0.)
            weight.squeeze(-1)# TODO
    
        gate = torch.matmul(weight, gate) + bias[None, :self.dim_seq, None]   # WZ + b
    
        #gate = F.conv1d(gate, weight, bias)   # WZ + b
    
        if exists(gate_res):
            gate = gate + gate_res
    
        return self.act(gate) * res
    

    `

    opened by ZIZUN 0
  • Potentially missing the high way pass

    Potentially missing the high way pass

    Hello,

    Maybe I missed it, but would you mind pointing out where the high way pass of the gMLP block is in the code? Based on the paper, there is a high way path (addition) between the input and the output. I couldn't find it in the gMLPBlock code.

    Thank you

    opened by Vincent-Li-9701 1
Owner
Phil Wang
Working with Attention. It's all we need.
Phil Wang
Code for Learning Manifold Patch-Based Representations of Man-Made Shapes, in ICLR 2021.

LearningPatches | Webpage | Paper | Video Learning Manifold Patch-Based Representations of Man-Made Shapes Dmitriy Smirnov, Mikhail Bessmeltsev, Justi

Dima Smirnov 22 Nov 14, 2022
DeepCO3: Deep Instance Co-segmentation by Co-peak Search and Co-saliency

[CVPR19] DeepCO3: Deep Instance Co-segmentation by Co-peak Search and Co-saliency (Oral paper) Authors: Kuang-Jui Hsu, Yen-Yu Lin, Yung-Yu Chuang PDF:

Kuang-Jui Hsu 139 Dec 22, 2022
PyTorch implementation of Glow

glow-pytorch PyTorch implementation of Glow, Generative Flow with Invertible 1x1 Convolutions (https://arxiv.org/abs/1807.03039) Usage: python train.p

Kim Seonghyeon 433 Dec 27, 2022
Tutorial in Python targeted at Epidemiologists. Will discuss the basics of analysis in Python 3

Python-for-Epidemiologists This repository is an introduction to epidemiology analyses in Python. Additionally, the tutorials for my library zEpid are

Paul Zivich 120 Nov 17, 2022
[CVPR 2021] MiVOS - Scribble to Mask module

MiVOS (CVPR 2021) - Scribble To Mask Ho Kei Cheng, Yu-Wing Tai, Chi-Keung Tang [arXiv] [Paper PDF] [Project Page] A simplistic network that turns scri

Rex Cheng 65 Dec 22, 2022
A PyTorch Implementation of PGL-SUM from "Combining Global and Local Attention with Positional Encoding for Video Summarization", Proc. IEEE ISM 2021

PGL-SUM: Combining Global and Local Attention with Positional Encoding for Video Summarization PyTorch Implementation of PGL-SUM From "PGL-SUM: Combin

Evlampios Apostolidis 35 Dec 22, 2022
Vehicle Detection Using Deep Learning and YOLO Algorithm

VehicleDetection Vehicle Detection Using Deep Learning and YOLO Algorithm Dataset take or find vehicle images for create a special dataset for fine-tu

Maryam Boneh 96 Jan 05, 2023
HiddenMarkovModel implements hidden Markov models with Gaussian mixtures as distributions on top of TensorFlow

Class HiddenMarkovModel HiddenMarkovModel implements hidden Markov models with Gaussian mixtures as distributions on top of TensorFlow 2.0 Installatio

Susara Thenuwara 2 Nov 03, 2021
This repository includes code of my study about Asynchronous in Frequency domain of GAN images.

Exploring the Asynchronous of the Frequency Spectra of GAN-generated Facial Images Binh M. Le & Simon S. Woo, "Exploring the Asynchronous of the Frequ

4 Aug 06, 2022
WSDM2022 "A Simple but Effective Bidirectional Extraction Framework for Relational Triple Extraction"

BiRTE WSDM2022 "A Simple but Effective Bidirectional Extraction Framework for Relational Triple Extraction" Requirements The main requirements are: py

9 Dec 27, 2022
ECCV18 Workshops - Enhanced SRGAN. Champion PIRM Challenge on Perceptual Super-Resolution. The training codes are in BasicSR.

ESRGAN (Enhanced SRGAN) [ 🚀 BasicSR] [Real-ESRGAN] ✨ New Updates. We have extended ESRGAN to Real-ESRGAN, which is a more practical algorithm for rea

Xintao 4.7k Jan 02, 2023
A project to make Amazon Echo respond to sign language using your webcam

Making Alexa respond to Sign Language using Tensorflow.js Try the live demo Read the Blog Post on Tensorflow's Blog Coming Soon Watch the video This p

Abhishek Singh 444 Jan 03, 2023
Library for 8-bit optimizers and quantization routines.

bitsandbytes Bitsandbytes is a lightweight wrapper around CUDA custom functions, in particular 8-bit optimizers and quantization functions. Paper -- V

Facebook Research 687 Jan 04, 2023
A 1.3B text-to-image generation model trained on 14 million image-text pairs

minDALL-E on Conceptual Captions minDALL-E, named after minGPT, is a 1.3B text-to-image generation model trained on 14 million image-text pairs for no

Kakao Brain 604 Dec 14, 2022
Official implementation of NeurIPS'21: Implicit SVD for Graph Representation Learning

isvd Official implementation of NeurIPS'21: Implicit SVD for Graph Representation Learning If you find this code useful, you may cite us as: @inprocee

Sami Abu-El-Haija 16 Jan 08, 2023
3D Avatar Lip Syncronization from speech (JALI based face-rigging)

visemenet-inference Inference Demo of "VisemeNet-tensorflow" VisemeNet is an audio-driven animator centric speech animation driving a JALI or standard

Junhwan Jang 17 Dec 20, 2022
Learning Efficient Online 3D Bin Packing on Packing Configuration Trees

Learning Efficient Online 3D Bin Packing on Packing Configuration Trees This repository is being continuously updated, please stay tuned! Any code con

86 Dec 28, 2022
[ICML 2020] "When Does Self-Supervision Help Graph Convolutional Networks?" by Yuning You, Tianlong Chen, Zhangyang Wang, Yang Shen

When Does Self-Supervision Help Graph Convolutional Networks? PyTorch implementation for When Does Self-Supervision Help Graph Convolutional Networks?

Shen Lab at Texas A&M University 106 Nov 11, 2022
An executor that loads ONNX models and embeds documents using the ONNX runtime.

ONNXEncoder An executor that loads ONNX models and embeds documents using the ONNX runtime. Usage via Docker image (recommended) from jina import Flow

Jina AI 2 Mar 15, 2022
Open-sourcing the Slates Dataset for recommender systems research

FINN.no Recommender Systems Slate Dataset This repository accompany the paper "Dynamic Slate Recommendation with Gated Recurrent Units and Thompson Sa

FINN.no 48 Nov 28, 2022