Implementation of Axial attention - attending to multi-dimensional data efficiently

Overview

Axial Attention

An implementation of axial attention in PyTorch, a simple but powerful technique for attending to multi-dimensional data efficiently. It has worked wonders for me and many other researchers.
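
To make "efficiently" concrete, here is a rough count, assuming the cost is dominated by the number of pairwise attention scores. Full self-attention over an H x W grid compares every position with every other position, while axial attention compares each position only with those sharing its row or its column. The helper names below are illustrative and not part of the library.

```python
def full_attention_scores(height, width):
    """Pairwise scores when every position attends to every other position."""
    n = height * width
    return n * n


def axial_attention_scores(height, width):
    """Pairwise scores when each position attends along its row and its column."""
    n = height * width
    return n * (height + width)


h, w = 256, 256
full = full_attention_scores(h, w)
axial = axial_attention_scores(h, w)
print(full)           # 4294967296
print(axial)          # 33554432
print(full // axial)  # 128
```

For a 256 x 256 image the axial variant needs 128 times fewer scores under this simplified count; constant factors and projection costs are ignored.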

Simply add some positional encoding to your data and pass it into this handy class, specifying which dimension is the embedding and how many axial dimensions to rotate through. All the permuting and reshaping will be taken care of for you.
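
As a rough illustration of the permuting and reshaping involved, the sketch below attends along a single axis using plain PyTorch. It is a simplification under stated assumptions: single head, no learned query/key/value projections, and non-negative indices. The function name `attend_along_axis` is hypothetical; it is not the library's code.

```python
import torch


def attend_along_axis(x, axis, dim_index=1):
    """Self-attention along one axis of x (single head, no learned weights)."""
    # Put the attended axis second-to-last and the embedding last.
    others = [d for d in range(x.ndim) if d not in (axis, dim_index)]
    perm = others + [axis, dim_index]
    x_perm = x.permute(*perm)
    shape = x_perm.shape

    # Fold batch and all remaining axes together: (batch * others, length, dim).
    tokens = x_perm.reshape(-1, shape[-2], shape[-1])
    scores = tokens @ tokens.transpose(-1, -2) / shape[-1] ** 0.5
    out = scores.softmax(dim=-1) @ tokens

    # Undo the reshape, then invert the permutation.
    inv = [perm.index(d) for d in range(x.ndim)]
    return out.reshape(shape).permute(*inv)


img = torch.randn(1, 3, 8, 8)
rows = attend_along_axis(img, axis=3)  # each pixel attends within its row
cols = attend_along_axis(img, axis=2)  # each pixel attends within its column
out = rows + cols                      # summed axial contributions
print(out.shape)  # torch.Size([1, 3, 8, 8])
```

Because every row (or column) is folded into the batch dimension, changing one row leaves the row-wise attention output of all other rows untouched.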

This paper was actually rejected on the basis of being too simple. And yet, it has since been used successfully in a number of applications, among them weather prediction and all-attention image segmentation. Just goes to show.

Install

$ pip install axial_attention

Usage

Image

import torch
from axial_attention import AxialAttention

img = torch.randn(1, 3, 256, 256)

attn = AxialAttention(
    dim = 3,               # embedding dimension
    dim_index = 1,         # where is the embedding dimension
    dim_heads = 32,        # dimension of each head. defaults to dim // heads if not supplied
    heads = 1,             # number of heads for multi-head attention
    num_dimensions = 2,    # number of axial dimensions (images is 2, video is 3, or more)
    sum_axial_out = True   # whether to sum the contributions of attention on each axis, or to run the input through them sequentially. defaults to true
)

attn(img) # (1, 3, 256, 256)
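
The two behaviours selected by `sum_axial_out` can be sketched as follows. This is a toy model only: `combine_axial` is a hypothetical helper, and the two small functions stand in for "attention along height" and "attention along width" so the difference is easy to follow by hand.

```python
def combine_axial(x, axial_fns, sum_axial_out=True):
    """Combine per-axis callables either by summing or by chaining them."""
    if sum_axial_out:
        # Every axis sees the same input; the outputs are added together.
        return sum(fn(x) for fn in axial_fns)
    # Each axis receives the output of the previous one.
    for fn in axial_fns:
        x = fn(x)
    return x


def double(v):
    """Toy stand-in for attention along one axis."""
    return 2 * v


def add_one(v):
    """Toy stand-in for attention along another axis."""
    return v + 1


print(combine_axial(3, [double, add_one], sum_axial_out=True))   # 10
print(combine_axial(3, [double, add_one], sum_axial_out=False))  # 7
```

In summed mode the order of the axes does not matter; in sequential mode it does.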

Channel-last image latents

import torch
from axial_attention import AxialAttention

img = torch.randn(1, 20, 20, 512)

attn = AxialAttention(
    dim = 512,           # embedding dimension
    dim_index = -1,      # where is the embedding dimension
    heads = 8,           # number of heads for multi-head attention
    num_dimensions = 2,  # number of axial dimensions (images is 2, video is 3, or more)
)

attn(img) # (1, 20, 20, 512)

Video

import torch
from axial_attention import AxialAttention

video = torch.randn(1, 5, 128, 256, 256)

attn = AxialAttention(
    dim = 128,           # embedding dimension
    dim_index = 2,       # where is the embedding dimension
    heads = 8,           # number of heads for multi-head attention
    num_dimensions = 3,  # number of axial dimensions (images is 2, video is 3, or more)
)

attn(video) # (1, 5, 128, 256, 256)

Image Transformer, with reversible network

import torch
from torch import nn
from axial_attention import AxialImageTransformer

conv1x1 = nn.Conv2d(3, 128, 1)

transformer = AxialImageTransformer(
    dim = 128,
    depth = 12,
    reversible = True
)

img = torch.randn(1, 3, 512, 512)

transformer(conv1x1(img)) # (1, 128, 512, 512)

With axial positional embedding

import torch
from axial_attention import AxialAttention, AxialPositionalEmbedding

img = torch.randn(1, 512, 20, 20)

attn = AxialAttention(
    dim = 512,
    heads = 8,
    dim_index = 1
)

pos_emb = AxialPositionalEmbedding(
    dim = 512,
    shape = (20, 20)
)

img = pos_emb(img)  # (1, 512, 20, 20)  - now positionally embedded
img = attn(img)     # (1, 512, 20, 20)
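
For intuition, an axial positional embedding can be thought of as one learned parameter per axial dimension, each with singleton dimensions everywhere else, added to the input by broadcasting. The class below is a hypothetical re-creation for illustration (it assumes a non-negative `emb_dim_index`), not the library's own implementation.

```python
import torch
from torch import nn


class SketchAxialPositionalEmbedding(nn.Module):
    """Illustrative only: one learned parameter per axial dimension."""

    def __init__(self, dim, shape, emb_dim_index=1):
        super().__init__()
        total_dims = len(shape) + 2  # batch + embedding + axial dimensions
        axial_indexes = [i for i in range(1, total_dims) if i != emb_dim_index]
        self.params = nn.ParameterList()
        for axial_len, axial_index in zip(shape, axial_indexes):
            param_shape = [1] * total_dims
            param_shape[emb_dim_index] = dim
            param_shape[axial_index] = axial_len
            self.params.append(nn.Parameter(torch.randn(*param_shape)))

    def forward(self, x):
        for param in self.params:
            x = x + param  # broadcasts over batch and the other axial dimensions
        return x


pos_emb = SketchAxialPositionalEmbedding(dim=512, shape=(20, 20))
print([tuple(p.shape) for p in pos_emb.params])    # [(1, 512, 20, 1), (1, 512, 1, 20)]
print(pos_emb(torch.zeros(1, 512, 20, 20)).shape)  # torch.Size([1, 512, 20, 20])
print(sum(p.numel() for p in pos_emb.params))      # 20480
```

A full per-position table for the same 20 x 20 grid would need 512 * 400 = 204800 values, so the factorised form stores ten times fewer parameters here.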

Citation

@misc{ho2019axial,
    title  = {Axial Attention in Multidimensional Transformers},
    author = {Jonathan Ho and Nal Kalchbrenner and Dirk Weissenborn and Tim Salimans},
    year   = {2019},
    archivePrefix = {arXiv}
}
@misc{wang2020axialdeeplab,
    title   = {Axial-DeepLab: Stand-Alone Axial-Attention for Panoptic Segmentation},
    author  = {Huiyu Wang and Yukun Zhu and Bradley Green and Hartwig Adam and Alan Yuille and Liang-Chieh Chen},
    year    = {2020},
    eprint  = {2003.07853},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@inproceedings{huang2019ccnet,
    title   = {CCNet: Criss-Cross Attention for Semantic Segmentation},
    author  = {Huang, Zilong and Wang, Xinggang and Huang, Lichao and Huang, Chang and Wei, Yunchao and Liu, Wenyu},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision},
    pages   = {603--612},
    year    = {2019}
}
Comments
  • Reimplementation of image modeling results in AXIAL ATTENTION IN MULTIDIMENSIONAL TRANSFORMERS.

    Hi, this is a nice paper. How can I use your shared code to reimplement the image modeling task on ImageNet 32x32?

    Thanks. Looking forward to your reply.

    opened by liujiaheng 3
  • AxialPositionalEmbedding

    Would you be able to provide an example of how to add the positional encoding with the AxialPositionalEmbedding class or explain what the emb_dim, emb_dim_index, and dimensions arguments are specifically? Thanks for the repo!

    opened by dansola 2
  • Problem of ParameterList with nn.DataParallel

    https://github.com/lucidrains/axial-attention/blob/a1a483c0f4a3922eef8f9a857dc1a802523bd437/axial_attention/axial_attention.py#L100

    This line would lead to the following issue: "UserWarning: nn.ParameterList is being used with DataParallel but this is not supported. This list will appear empty for the models replicated on each GPU except the original one."

    This is a known issue.

    The simple solution should be to store the Parameters directly on the Module.

    import torch
    from torch import nn

    class AxialPositionalEmbedding(nn.Module):
        def __init__(self, dim, shape, emb_dim_index = 1):
            super().__init__()
            total_dimensions = len(shape) + 2
            ax_dim_indexes = [i for i in range(1, total_dimensions) if i != emb_dim_index]

            # register each axial parameter directly on the module instead of in a ParameterList
            self.param_num = len(shape)
            for i, (axial_dim, axial_dim_index) in enumerate(zip(shape, ax_dim_indexes)):
                param_shape = [1] * total_dimensions
                param_shape[emb_dim_index] = dim
                param_shape[axial_dim_index] = axial_dim
                setattr(self, f'param_{i}', nn.Parameter(torch.randn(*param_shape)))

        def forward(self, x):
            # add every axial embedding; broadcasting expands the singleton dimensions
            for i in range(self.param_num):
                x = x + getattr(self, f'param_{i}')
            return x
    
    opened by resuly 1
  • Positional embeddings for different image sizes

    Hi, once again thanks for your great work! I want to use axial attention with positional embedding for image sizes that are not known in advance (though I do know the maximum size). I was wondering whether changing https://github.com/lucidrains/axial-attention/blob/master/axial_attention/axial_attention.py#L104 to

    for cnt, param in enumerate(self.params):
        x = x + param[tuple([slice(None)] * (cnt + 2) + [slice(x.shape[cnt + 2])])]

    does the right thing. With that change I can now do this:

    v = AxialImageTransformer(64, depth = 1, axial_pos_emb_shape = (64,64), dim_index = 1)       
    t1 = torch.randn(2, 64, 17, 16)
    t2 = torch.randn(2, 64, 13, 18)
    t3 = torch.randn(2, 64, 64, 64)
    print(v(t1).shape)
    print(v(t2).shape)
    print(v(t3).shape)
    Output:
    torch.Size([2, 64, 17, 16])
    torch.Size([2, 64, 13, 18])
    torch.Size([2, 64, 64, 64])
    

    I think that makes it easier to integrate it in fully convolutional nets for multi scale training.

    opened by PhilippMarquardt 1
  • User Warning: Mixed memory format inputs detected

    At site-packages/axial_attention/axial_attention.py:176:

    UserWarning: Mixed memory format inputs detected while calling the operator. The operator will output contiguous tensor even if some of the inputs are in channels_last format. (Triggered internally at /opt/conda/conda-bld/pytorch_1595629427286/work/aten/src/ATen/native/TensorIterator.cpp:918.)
      return sum(map(lambda axial_attn: axial_attn(x), self.axial_attentions))

    I am using the latest axial_attention (v0.4) and PyTorch 1.6.0.

    Code:

    import torch
    from axial_attention import AxialAttention
    
    img = torch.randn(1, 24, 64, 64)
    
    attn = AxialAttention(
        dim = 24,               # embedding dimension
        dim_index = 1,         # where is the embedding dimension
        dim_heads = 32,        # dimension of each head. defaults to dim // heads if not supplied
        heads = 8,             # number of heads for multi-head attention
        num_dimensions = 2,    # number of axial dimensions (images is 2, video is 3, or more)
        sum_axial_out = True   # whether to sum the contributions of attention on each axis, or to run the input through them sequentially. defaults to true
    )
    
    out = attn(img)

    Will this affect training and inference?

    opened by lokeshkvn 1
  • Examples for image sequence/video

    Hello, do you have examples of integrating this on image sequences? I am trying to get rid of ConvLSTMs for encoding sequences of images, and AxialAttention may be a good starting point. Do you have an example or notebook that I could look at to integrate this with my type of data? Thank you for this amazing work. Thomas

    opened by tcapelle 1
  • Ask a question

    I'm interested in your excellent work, but I'm new to PyTorch. Where in the code should I start reading in order to understand the whole project? Thanks for your reply.

    opened by meiguoofa 0
  • Hi, I have a problem

    import torch
    from axial_attention import AxialAttention

    img = torch.randn(1, 3, 256, 256)

    attn = AxialAttention(
        dim = 3,               # embedding dimension
        dim_index = 1,         # where is the embedding dimension
        dim_heads = 32,        # dimension of each head. defaults to dim // heads if not supplied
        heads = 1,             # number of heads for multi-head attention
        num_dimensions = 2,    # number of axial dimensions (images is 2, video is 3, or more)
        sum_axial_out = True   # whether to sum the contributions of attention on each axis, or to run the input through them sequentially. defaults to true
    )

    attn(img) # (1, 3, 256, 256)

    Thanks for your great project. If my image has only one channel, does that influence the num_dimensions value?

    opened by meiguoofa 0
  • Extracting attention maps

    Hi there,

    Excellent project!

    I'm using axial-attention with video (1, 5, 128, 256, 256) and sum_axial_out=True, and I wish to visualise the attention maps.

    Essentially, given my video, and two frame indices frame_a_idx and frame_b_idx, I need to extract the attention map over frame_b to a chosen pixel (x, y) in frame_a (after the axial sum).

    My understanding is that I should be able to reshape the dots (after softmax) according to the permutations in calculate_permutations, then sum these permuted dots together to form a final attention score tensor of an accessible shape, thus ready for visualisation.

    I am slightly stuck due to the numerous axial permutations and shape mismatches. What I am doing is as follows:

    In SelfAttention.forward():

    dots_reshaped = dots.reshape(b, h, t, t)
    return out, dots_reshaped
    

    In PermuteToFrom.forward():

    # attention
    axial, dots = self.fn(axial, **kwargs)
    
    # restore to original shape and permutation
    axial = axial.reshape(*shape)
    axial = axial.permute(*self.inv_permutation).contiguous()
    dots = dots.reshape(*shape[:3], *dots.shape[1:])
    

    However, I am unsure of how to un-permute the dots appropriately such that all resulting “axes” (of different sizes) can be summed. If you have suggestions or code for doing so, it would be very much appreciated, thanks!

    opened by vibrant-galaxy 3
Releases(0.6.1)
Owner
Phil Wang
Working with Attention. It's all we need