Implementation of general-purpose self-attention mechanisms, focused on computer vision modules. Ongoing repository.

Overview

Self-attention building blocks for computer vision applications in PyTorch

Implementation of self-attention mechanisms for computer vision in PyTorch, written with einsum and einops.

Install it via pip

It is best to install PyTorch in your environment beforehand, especially if you don't have a GPU.

pip install self-attention-cv

Related articles

More articles are on the way.

Code Examples

Multi-head attention

import torch
from self_attention_cv import MultiHeadSelfAttention

model = MultiHeadSelfAttention(dim=64)
x = torch.rand(16, 10, 64)  # [batch, tokens, dim]
mask = torch.zeros(10, 10, dtype=torch.bool)  # [tokens, tokens]
mask[5:8, 5:8] = True
y = model(x, mask)
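To make the shapes above concrete, here is a minimal NumPy sketch of masked multi-head self-attention written with einsum. It is not the library's code and uses NumPy only so that it runs without PyTorch; the fused projection matrix w_qkv, the output projection w_out, the missing biases, and the convention that True in the mask means "ignore this query/key pair" are all assumptions of the sketch.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the row maximum for numerical stability; exp(-inf) becomes 0.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def multi_head_self_attention(x, w_qkv, w_out, heads, mask=None):
    """x: [batch, tokens, dim], w_qkv: [dim, 3 * dim], w_out: [dim, dim].
    mask: optional boolean [tokens, tokens]; True marks a query/key pair
    to ignore (a convention assumed for this sketch)."""
    b, t, d = x.shape
    dim_head = d // heads
    # Project once, then split into q, k, v and heads: [3, batch, heads, tokens, dim_head]
    qkv = (x @ w_qkv).reshape(b, t, 3, heads, dim_head).transpose(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]
    # Scaled dot products between all token pairs: [batch, heads, tokens, tokens]
    scores = np.einsum('bhid,bhjd->bhij', q, k) / np.sqrt(dim_head)
    if mask is not None:
        scores = np.where(mask, -np.inf, scores)  # broadcast over batch and heads
    attn = softmax(scores, axis=-1)
    out = np.einsum('bhij,bhjd->bhid', attn, v)  # weighted sum of the values
    # Concatenate the heads again and apply the output projection: [batch, tokens, dim]
    out = out.transpose(0, 2, 1, 3).reshape(b, t, d)
    return out @ w_out


rng = np.random.default_rng(0)
x = rng.random((16, 10, 64))  # [batch, tokens, dim]
w_qkv = rng.standard_normal((64, 3 * 64)) / 8.0
w_out = rng.standard_normal((64, 64)) / 8.0
mask = np.zeros((10, 10), dtype=bool)
mask[5:8, 5:8] = True
y = multi_head_self_attention(x, w_qkv, w_out, heads=8, mask=mask)
print(y.shape)  # (16, 10, 64)
```

Only query rows 5 to 7 lose some keys here, so every softmax row keeps at least one finite score; a row masked completely would produce NaNs in this sketch.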

Axial attention

import torch
from self_attention_cv import AxialAttentionBlock
model = AxialAttentionBlock(in_channels=256, dim=64, heads=8)
x = torch.rand(1, 256, 64, 64)  # [batch, in_channels, dim, dim]
y = model(x)
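Axial attention (Wang et al., 2020) factorizes 2D self-attention: each pixel first attends along its column and then along its row, so one pass looks at height + width positions instead of height * width, while the two passes together still connect every pixel to every other one. The NumPy sketch below shows only that factorization. It is single-head, uses q = k = v = x with no learned projections or positional terms, and its function names are hypothetical; AxialAttentionBlock itself is more involved.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def attend_along_tokens(x):
    """Single-head self-attention over axis -2 of x: [..., tokens, channels].
    Uses q = k = v = x, i.e. no learned projections (a simplification)."""
    c = x.shape[-1]
    scores = np.einsum('...ic,...jc->...ij', x, x) / np.sqrt(c)
    return np.einsum('...ij,...jc->...ic', softmax(scores), x)


def axial_attention(x):
    """x: [batch, channels, height, width] -> same shape."""
    # Height axis: every column becomes a sequence of `height` tokens.
    cols = x.transpose(0, 3, 2, 1)  # [batch, width, height, channels]
    x = attend_along_tokens(cols).transpose(0, 3, 2, 1)
    # Width axis: every row becomes a sequence of `width` tokens.
    rows = x.transpose(0, 2, 3, 1)  # [batch, height, width, channels]
    return attend_along_tokens(rows).transpose(0, 3, 1, 2)


rng = np.random.default_rng(0)
x = rng.random((1, 8, 16, 16))  # [batch, channels, height, width], small for speed
y = axial_attention(x)
print(y.shape)  # (1, 8, 16, 16)
```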

Vanilla Transformer Encoder

import torch
from self_attention_cv import TransformerEncoder
model = TransformerEncoder(dim=64, blocks=6, heads=8)
x = torch.rand(16, 10, 64)  # [batch, tokens, dim]
mask = torch.zeros(10, 10, dtype=torch.bool)  # [tokens, tokens]
mask[5:8, 5:8] = True
y = model(x, mask)
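The blocks=6 argument stacks six encoder blocks. Below is a minimal NumPy sketch of that stacking, assuming the post-norm layout of Vaswani et al. (2017) (sublayer, residual add, LayerNorm), single-head attention for brevity, a ReLU feed-forward layer, and no dropout or learned LayerNorm parameters. The parameter dictionary and its keys are hypothetical, and TransformerEncoder's internals may differ.

```python
import numpy as np

rng = np.random.default_rng(0)


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def layer_norm(x, eps=1e-5):
    # Normalize every token over its feature axis (no learned scale or shift here).
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


def self_attention(x, p, mask=None):
    # Single-head scaled dot-product attention, kept short on purpose.
    q, k, v = x @ p['wq'], x @ p['wk'], x @ p['wv']
    scores = np.einsum('bid,bjd->bij', q, k) / np.sqrt(q.shape[-1])
    if mask is not None:
        scores = np.where(mask, -np.inf, scores)  # True = ignore this pair
    return softmax(scores) @ v


def encoder_block(x, p, mask=None):
    # Post-norm layout: sublayer, residual add, then LayerNorm.
    y = layer_norm(x + self_attention(x, p, mask))
    ffn = np.maximum(y @ p['w1'], 0.0) @ p['w2']  # Linear -> ReLU -> Linear
    return layer_norm(y + ffn)


def make_block_params(dim, dim_ff):
    # Hypothetical random weights, scaled by 1 / sqrt(fan_in).
    s = 1.0 / np.sqrt(dim)
    return {
        'wq': rng.standard_normal((dim, dim)) * s,
        'wk': rng.standard_normal((dim, dim)) * s,
        'wv': rng.standard_normal((dim, dim)) * s,
        'w1': rng.standard_normal((dim, dim_ff)) * s,
        'w2': rng.standard_normal((dim_ff, dim)) / np.sqrt(dim_ff),
    }


def transformer_encoder(x, blocks, mask=None):
    # The same block structure applied len(blocks) times, each with its own weights.
    for p in blocks:
        x = encoder_block(x, p, mask)
    return x


blocks = [make_block_params(dim=64, dim_ff=256) for _ in range(6)]  # blocks=6
x = rng.random((16, 10, 64))  # [batch, tokens, dim]
mask = np.zeros((10, 10), dtype=bool)
mask[5:8, 5:8] = True
y = transformer_encoder(x, blocks, mask)
print(y.shape)  # (16, 10, 64)
```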

Vision Transformer with/without ResNet50 backbone for image classification

import torch
from self_attention_cv import ViT, ResNet50ViT

model1 = ResNet50ViT(img_dim=128, pretrained_resnet=False,
                     blocks=6, num_classes=10,
                     dim_linear_block=256, dim=256)
# or
model2 = ViT(img_dim=256, in_channels=3, patch_dim=16, num_classes=10, dim=512)
x = torch.rand(2, 3, 256, 256)
y = model2(x)  # [2, 10]
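With img_dim=256 and patch_dim=16, the image is cut into (256 / 16)^2 = 256 non-overlapping patches, each flattened to 16 * 16 * 3 = 768 values before being projected to dim=512; the ViT paper (Dosovitskiy et al., 2020) also prepends a learnable class token. The NumPy sketch below shows the patch split only. The patchify name and the row-major patch order (einops pattern 'b c (x p1) (y p2) -> b (x y) (p1 p2 c)') are assumptions and may not match the library's ordering.

```python
import numpy as np


def patchify(img, patch_dim):
    """img: [batch, channels, H, W] -> [batch, num_patches, patch_dim * patch_dim * channels]."""
    b, c, h, w = img.shape
    assert h % patch_dim == 0 and w % patch_dim == 0, 'image size must be divisible by patch_dim'
    x, y = h // patch_dim, w // patch_dim
    # Split H into (x, p1) and W into (y, p2), then order as [b, x, y, p1, p2, c].
    patches = img.reshape(b, c, x, patch_dim, y, patch_dim).transpose(0, 2, 4, 3, 5, 1)
    return patches.reshape(b, x * y, patch_dim * patch_dim * c)


rng = np.random.default_rng(0)
img = rng.random((2, 3, 256, 256))
tokens = patchify(img, 16)
print(tokens.shape)  # (2, 256, 768)
```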

A re-implementation of Unet with the Vision Transformer encoder

import torch
from self_attention_cv.transunet import TransUnet

a = torch.rand(2, 3, 128, 128)
model = TransUnet(in_channels=3, img_dim=128, vit_blocks=8,
                  vit_dim_linear_mhsa_block=512, classes=5)
y = model(a)  # [2, 5, 128, 128]

Bottleneck Attention block

import torch
from self_attention_cv.bottleneck_transformer import BottleneckBlock
inp = torch.rand(1, 512, 32, 32)
bottleneck_block = BottleneckBlock(in_channels=512, fmap_size=(32, 32), heads=4,
                                   out_channels=1024, pooling=True)
y = bottleneck_block(inp)
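Following Srinivas et al. (2021), a bottleneck transformer block is a ResNet bottleneck whose 3x3 spatial convolution is replaced by global multi-head self-attention over all fmap_size positions. The NumPy sketch below illustrates only that substitution under simplifying assumptions: no batch normalization, no learned q/k/v projections or relative position logits, a 4x channel reduction inside the block, and 2x2 average pooling on both paths when pooling is enabled. The helper names and the params dictionary are hypothetical, so BottleneckBlock's internals need not match it.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def conv1x1(x, w):
    # A 1x1 convolution is a per-pixel matrix multiply; w is [out_ch, in_ch].
    return np.einsum('oc,bchw->bohw', w, x)


def avg_pool2x2(x):
    b, c, h, w = x.shape
    return x.reshape(b, c, h // 2, 2, w // 2, 2).mean(axis=(3, 5))


def global_mhsa(x, heads):
    """Self-attention over all H * W positions of x: [b, c, h, w].
    Channels are split into heads; q = k = v (sketch only)."""
    b, c, h, w = x.shape
    t = x.reshape(b, heads, c // heads, h * w)  # [b, heads, dim_head, tokens]
    scores = np.einsum('bhdi,bhdj->bhij', t, t) / np.sqrt(c // heads)
    out = np.einsum('bhij,bhdj->bhdi', softmax(scores), t)
    return out.reshape(b, c, h, w)


def bottleneck_block(x, params, heads, pooling):
    relu = lambda z: np.maximum(z, 0.0)
    y = relu(conv1x1(x, params['reduce']))  # 1x1: in_ch -> out_ch // 4
    y = relu(global_mhsa(y, heads))         # MHSA in place of the 3x3 conv
    if pooling:
        y = avg_pool2x2(y)                  # stride-2 variant
    y = conv1x1(y, params['expand'])        # 1x1: out_ch // 4 -> out_ch
    shortcut = conv1x1(x, params['shortcut'])
    if pooling:
        shortcut = avg_pool2x2(shortcut)
    return relu(y + shortcut)


rng = np.random.default_rng(0)
in_ch, out_ch = 32, 64
params = {
    'reduce': rng.standard_normal((out_ch // 4, in_ch)) / np.sqrt(in_ch),
    'expand': rng.standard_normal((out_ch, out_ch // 4)) / np.sqrt(out_ch // 4),
    'shortcut': rng.standard_normal((out_ch, in_ch)) / np.sqrt(in_ch),
}
x = rng.random((1, in_ch, 8, 8))  # [batch, channels, height, width]
y = bottleneck_block(x, params, heads=4, pooling=True)
print(y.shape)  # (1, 64, 4, 4)
```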

Position embeddings are also available

1D Positional Embeddings

import torch
from self_attention_cv.pos_embeddings import AbsPosEmb1D, RelPosEmb1D

model = AbsPosEmb1D(tokens=20, dim_head=64)
q = torch.rand(2, 3, 20, 64)  # [batch, heads, tokens, dim_head]
y1 = model(q)

model = RelPosEmb1D(tokens=20, dim_head=64, heads=3)
q = torch.rand(2, 3, 20, 64)
y2 = model(q)
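Relative position embeddings give each query/key pair a term that depends only on the offset j - i between the two tokens, so for 20 tokens a table of 2 * 20 - 1 = 39 vectors covers every offset. In the bottleneck transformer formulation (Srinivas et al., 2021) the resulting logits q_i . r_(j-i) are added to the content logits q_i . k_j before the softmax. The NumPy sketch below computes them with a plain index gather; the table layout (row j - i + tokens - 1) is an assumption, and efficient implementations often use a reshaping trick instead, so RelPosEmb1D need not work this way internally.

```python
import numpy as np


def relative_logits_1d(q, rel_emb):
    """q: [batch, heads, tokens, dim_head]; rel_emb: [2 * tokens - 1, dim_head],
    where row (j - i) + tokens - 1 holds the embedding of offset j - i.
    Returns [batch, heads, tokens, tokens] with entry (i, j) = q_i . r_(j - i)."""
    t = q.shape[-2]
    # Dot every query with every offset embedding: [batch, heads, tokens, 2 * tokens - 1]
    all_offsets = np.einsum('bhid,rd->bhir', q, rel_emb)
    i = np.arange(t)[:, None]  # query positions, shape [tokens, 1]
    j = np.arange(t)[None, :]  # key positions, shape [1, tokens]
    # For each (i, j) pick the column of offset j - i, shifted to be non-negative.
    return all_offsets[:, :, i, j - i + t - 1]


rng = np.random.default_rng(0)
q = rng.random((2, 3, 20, 64))  # [batch, heads, tokens, dim_head]
rel_emb = rng.standard_normal((2 * 20 - 1, 64))
logits = relative_logits_1d(q, rel_emb)
print(logits.shape)  # (2, 3, 20, 20)
```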

2D Positional Embeddings

import torch
from self_attention_cv.pos_embeddings import RelPosEmb2D
dim = 32  # spatial dim of the feat map
model = RelPosEmb2D(
    feat_map_size=(dim, dim),
    dim_head=128)

q = torch.rand(2, 4, dim*dim, 128)
y = model(q)
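For a 2D feature map, one common factorization splits the offset between two pixels into a row offset and a column offset, each with its own embedding table, and adds the two logits; to our understanding this is the relative-height plus relative-width scheme of Srinivas et al. (2021), but RelPosEmb2D may differ in details. The NumPy sketch below assumes a row-major flattened map (token index x * width + y), tables indexed by offset + size - 1, and hypothetical names.

```python
import numpy as np


def relative_logits_2d(q, rel_h, rel_w, height, width):
    """q: [batch, heads, height * width, dim_head] (row-major flattened feature map).
    rel_h: [2 * height - 1, dim_head], rel_w: [2 * width - 1, dim_head] (offset tables).
    Returns [batch, heads, height * width, height * width]."""
    b, heads, n, d = q.shape
    assert n == height * width
    qg = q.reshape(b, heads, height, width, d)
    # Dot each query with every row offset and every column offset.
    lh = np.einsum('bhxyd,rd->bhxyr', qg, rel_h)  # [b, heads, H, W, 2H - 1]
    lw = np.einsum('bhxyd,rd->bhxyr', qg, rel_w)  # [b, heads, H, W, 2W - 1]
    xs, ys = np.arange(height), np.arange(width)
    ih = xs[None, :] - xs[:, None] + height - 1  # [H_query, H_key] offset indices
    iw = ys[None, :] - ys[:, None] + width - 1   # [W_query, W_key] offset indices
    # Gather per query pixel (x, y): [b, heads, H, W, H_key] and [b, heads, H, W, W_key]
    gh = lh[:, :, xs[:, None, None], ys[None, :, None], ih[:, None, :]]
    gw = lw[:, :, xs[:, None, None], ys[None, :, None], iw[None, :, :]]
    # Logit for key pixel (x2, y2) = row-offset term + column-offset term.
    logits = gh[..., :, None] + gw[..., None, :]  # [b, heads, H, W, H, W]
    return logits.reshape(b, heads, n, n)


rng = np.random.default_rng(0)
H, W, dim_head = 4, 5, 8
q = rng.random((2, 4, H * W, dim_head))
rel_h = rng.standard_normal((2 * H - 1, dim_head))
rel_w = rng.standard_normal((2 * W - 1, dim_head))
logits = relative_logits_2d(q, rel_h, rel_w, H, W)
print(logits.shape)  # (2, 4, 20, 20)
```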

References

  1. Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., ... & Polosukhin, I. (2017). Attention is all you need. arXiv preprint arXiv:1706.03762.
  2. Wang, H., Zhu, Y., Green, B., Adam, H., Yuille, A., & Chen, L. C. (2020, August). Axial-deeplab: Stand-alone axial-attention for panoptic segmentation. In European Conference on Computer Vision (pp. 108-126). Springer, Cham.
  3. Srinivas, A., Lin, T. Y., Parmar, N., Shlens, J., Abbeel, P., & Vaswani, A. (2021). Bottleneck Transformers for Visual Recognition. arXiv preprint arXiv:2101.11605.
  4. Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., ... & Houlsby, N. (2020). An image is worth 16x16 words: Transformers for image recognition at scale. arXiv preprint arXiv:2010.11929.
Comments
  • Thank you very much for the code. But when I run test_TransUnet.py, it starts reporting errors. Why is that? Could you please help me solve it? Thank you

    Thank you very much for the code. But when I run test_TransUnet.py, it starts reporting errors:

    Traceback (most recent call last):
      File "self-attention-cv/tests/test_TransUnet.py", line 14, in test_TransUnet()
      File "/self-attention-cv/tests/test_TransUnet.py", line 11, in test_TransUnet
        y = model(a)
      File "C:\Users\dell.conda\envs\myenv\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
        result = self.forward(*input, **kwargs)
      File "self-attention-cv\self_attention_cv\transunet\trans_unet.py", line 88, in forward
        y = self.project_patches_back(y)
      File "C:\Users\dell.conda\envs\myenv\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
        result = self.forward(*input, **kwargs)
      File "C:\Users\dell.conda\envs\myenv\lib\site-packages\torch\nn\modules\linear.py", line 93, in forward
        return F.linear(input, self.weight, self.bias)
      File "C:\Users\dell.conda\envs\myenv\lib\site-packages\torch\nn\functional.py", line 1692, in linear
        output = input.matmul(weight.t())
    RuntimeError: mat1 dim 1 must match mat2 dim 0

    Process finished with exit code 1

    Could you please help me solve it? Thank you.

    opened by yezhengjie 7
  • TransUNet - Why is the patch_dim set to 1?

    Hi,

    Can you please explain why the patch_dim is set to 1 in the TransUNet class? Thank you in advance!

    https://github.com/The-AI-Summer/self-attention-cv/blob/8280009366b633921342db6cab08da17b46fdf1c/self_attention_cv/transunet/trans_unet.py#L54

    opened by dsitnik 7
  • Question: Sliding Window Module for Transformer3dSeg Object

    I was wondering whether you've implemented an example that uses the network for a 3D medical segmentation task. If this network only outputs the center slice of a patch, then we would need a wrapper function that iterates through all patches in an image to get the final prediction for the entire volume. From the original paper, I assume they choose 10 patches at random from an image during training, but it's not clear how they pieced everything together during testing.

    Your thoughts on this would be greatly appreciated!

    See: https://github.com/The-AI-Summer/self-attention-cv/blob/33ddf020d2d9fb9c4a4a3b9938383dc9b7405d8c/self_attention_cv/Transformer3Dsegmentation/tranf3Dseg.py#L10

    opened by jmarsil 5
  • ResNet + Pyramid Vision Transformer Version 2

    Thank you for your work and the clear explanation. As you know, ViT doesn't work well on small datasets, so I am implementing ResNet34 with Pyramid Vision Transformer Version 2 to improve on it. The architectures of ViT and PVT V2 are completely different. Could you please help me implement it?

    opened by khawar-islam 3
  • Request for Including UNETR

    Thanks for the great work! I noticed a nice implementation of this paper (https://arxiv.org/abs/2103.10504) here:

    https://github.com/tamasino52/UNETR/blob/main/unetr.py

    It would be great if this could also be included in your repo, since it comes with lots of other great features, so we can explore more.

    Thanks ~

    opened by Siyuan89 3
  • ImageNet Pretrained TimeSformer

    I see you have recently added the TimeSformer model to this repository. In the paper, they initialize their model weights from ImageNet-pretrained ViT weights. Does your implementation offer this too? Thanks!

    opened by RaivoKoot 3
  • Do the encoder modules incorporate positional encoding?

    I am wondering whether, if I use, say, the LinformerEncoder, I have to add the positional encoding myself or if that's already done. From the source files it doesn't seem to be there, but I'm not sure how to include the positional encoding, since the embedding modules seem to need the query, which isn't available when passing data directly to the LinformerEncoder. I may well be missing something; any help would be great. Perhaps an example using positional encoding would help.

    opened by jfkback 3
  • Use AxialAttention on GPU

    I tried to use AxialAttention on the GPU, but I get an error. Can you give me some tips about using AxialAttention on the GPU? Thanks! Error: RuntimeError: expected self and mask to be on the same device, but got mask on cpu and self on cuda:0

    opened by Iverson-Al 2
  • Axial attention

    What is the meaning of qkv_channels? https://github.com/The-AI-Summer/self-attention-cv/blob/5246e550ecb674f60df76a6c1011fde30ded7f44/self_attention_cv/axial_attention_deeplab/axial_attention.py#L32

    opened by Jayden9912 1
  • Convolution-Free Medical Image Segmentation using Transformers

    Thank you very much for your contribution. As a novice, I have a question. In tranf3dseg, the output of the model is the predicted segmentation of the center patch, so how can I get the segmentation of the whole input image? I am looking forward to any reply.

    opened by WinsaW 1
  • Regression with attention

    Hello!

    thanks for sharing this nice repo :)

    I'm trying to use ViT to do regression on images. I'd like to predict 6 floats per image.

    My understanding is that I'd need to simply define the network as

    vit = ViT(img_dim=128,
              in_channels=3,
              patch_dim=16,
              num_classes=6,
              dim=512)

    and during training call

    vit(x)

    and compute the loss as MSE instead of CE.

    The network actually runs but it doesn't seem to converge. Is there something obvious I am missing?

    many thanks!

    opened by alemelis 1
  • Segmentation for full image

    Hi,

    Thank you for your effort and time in implementing this. I have a quick question: I want to get the segmentation for the full image, not just for the middle token. Would it be correct to change self.tokens to self.p here:

    https://github.com/The-AI-Summer/self-attention-cv/blob/5246e550ecb674f60df76a6c1011fde30ded7f44/self_attention_cv/Transformer3Dsegmentation/tranf3Dseg.py#L66

    and change this:

    https://github.com/The-AI-Summer/self-attention-cv/blob/5246e550ecb674f60df76a6c1011fde30ded7f44/self_attention_cv/Transformer3Dsegmentation/tranf3Dseg.py#L94

    to

    y = self.mlp_seg_head(y)

    opened by aqibsaeed 0
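Several of the questions above (the sliding-window question for Transformer3dSeg and the two about segmenting the whole image) come down to the same step: if the model only predicts the central block of each input patch, the full volume can be covered by sliding the input window with a stride equal to the center size and stitching the center predictions together. A NumPy sketch under stated assumptions: predict_center is a hypothetical callable wrapping the model, borders are zero-padded, and the center blocks tile the volume without overlap (averaging overlapping predictions is a common alternative). It is not necessarily how the original paper assembled its test-time predictions.

```python
import numpy as np


def sliding_window_predict(volume, patch, center, predict_center):
    """Stitch a whole-volume prediction from a model that only predicts the
    central block of each input patch.
    volume: [D, H, W]; patch, center: per-axis sizes (3-tuples of ints).
    predict_center: hypothetical callable, patch array -> center-block array."""
    margin = [(p - c) // 2 for p, c in zip(patch, center)]
    # Round each axis up to a multiple of the center size so the center blocks tile it.
    out_shape = [s + (-s) % c for s, c in zip(volume.shape, center)]
    # Zero-pad so that every voxel can sit inside the center block of some window.
    pad = [(m, o - s + p - c - m)
           for s, o, p, c, m in zip(volume.shape, out_shape, patch, center, margin)]
    padded = np.pad(volume, pad, mode='constant')
    out = np.zeros(out_shape)
    for z in range(0, out_shape[0], center[0]):
        for y in range(0, out_shape[1], center[1]):
            for x in range(0, out_shape[2], center[2]):
                window = padded[z:z + patch[0], y:y + patch[1], x:x + patch[2]]
                out[z:z + center[0], y:y + center[1], x:x + center[2]] = predict_center(window)
    d, h, w = volume.shape
    return out[:d, :h, :w]  # drop the rounding-up region


patch, center = (16, 16, 16), (4, 4, 4)


def fake_predict_center(window):
    # Stand-in for the model: returns the window's own center block, so the
    # stitched result should reproduce the input volume exactly.
    m = [(p - c) // 2 for p, c in zip(patch, center)]
    return window[m[0]:m[0] + center[0], m[1]:m[1] + center[1], m[2]:m[2] + center[2]]


vol = np.random.default_rng(0).random((10, 13, 9))
pred = sliding_window_predict(vol, patch, center, fake_predict_center)
print(pred.shape, np.allclose(pred, vol))  # (10, 13, 9) True
```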
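On the TransUNet patch_dim question above: in the hybrid setting described in the ViT paper (Dosovitskiy et al., 2020), the Transformer input comes from a CNN feature map, and 1x1 patches are the special case in which the token sequence is simply the flattened feature map projected to the Transformer dimension. Whether that is the intent of this particular line is for the maintainers to confirm. The NumPy sketch below only demonstrates the mechanical effect of patch_dim=1; the patchify helper and the feature-map size are hypothetical.

```python
import numpy as np


def patchify(feat, patch_dim):
    """[batch, channels, H, W] -> [batch, (H / p) * (W / p), p * p * channels]."""
    b, c, h, w = feat.shape
    x, y = h // patch_dim, w // patch_dim
    patches = feat.reshape(b, c, x, patch_dim, y, patch_dim).transpose(0, 2, 4, 3, 5, 1)
    return patches.reshape(b, x * y, patch_dim * patch_dim * c)


rng = np.random.default_rng(0)
feat = rng.random((2, 32, 8, 8))  # hypothetical CNN feature map [batch, channels, H, W]
tokens = patchify(feat, patch_dim=1)
# With 1x1 patches every spatial position becomes one token carrying its channel vector.
flattened = feat.reshape(2, 32, 8 * 8).transpose(0, 2, 1)
print(tokens.shape, np.array_equal(tokens, flattened))  # (2, 64, 32) True
```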
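On the positional-encoding question above: a common approach, independent of this library, is to add a positional encoding to the token tensor itself before it enters the encoder, as in the original Transformer (Vaswani et al., 2017). The NumPy sketch below builds the fixed sinusoidal variant; adding a learned [tokens, dim] table in the same way is a common alternative. Whether LinformerEncoder already adds one internally is an assumption to check against its source.

```python
import numpy as np


def sinusoidal_positional_encoding(tokens, dim):
    """Fixed encoding of Vaswani et al. (2017), shape [tokens, dim] (dim must be even):
    PE[pos, 2i] = sin(pos / 10000^(2i / dim)), PE[pos, 2i + 1] = cos(pos / 10000^(2i / dim))."""
    assert dim % 2 == 0
    pos = np.arange(tokens)[:, None]                        # [tokens, 1]
    inv_freq = 1.0 / 10000 ** (np.arange(0, dim, 2) / dim)  # [dim / 2]
    pe = np.zeros((tokens, dim))
    pe[:, 0::2] = np.sin(pos * inv_freq)
    pe[:, 1::2] = np.cos(pos * inv_freq)
    return pe


rng = np.random.default_rng(0)
x = rng.random((16, 10, 64))  # token embeddings [batch, tokens, dim]
pe = sinusoidal_positional_encoding(tokens=10, dim=64)
x_with_pos = x + pe  # broadcast over the batch; this is what would go into the encoder
```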
Releases: 1.2.3
Owner
AI Summer
Learn Deep Learning and Artificial Intelligence