Implementation of Transformer in Transformer, pixel level attention paired with patch level attention for image classification, in Pytorch

Overview

Transformer in Transformer

Install

$ pip install transformer-in-transformer

Usage

import torch
from transformer_in_transformer import TNT

tnt = TNT(
    image_size = 256,       # size of image
    patch_dim = 512,        # dimension of patch token
    pixel_dim = 24,         # dimension of pixel token
    patch_size = 16,        # patch size
    pixel_size = 4,         # pixel size
    depth = 6,              # depth
    num_classes = 1000,     # output number of classes
    attn_dropout = 0.1,     # attention dropout
    ff_dropout = 0.1        # feedforward dropout
)

img = torch.randn(2, 3, 256, 256)
logits = tnt(img) # (2, 1000)

Citations

@misc{han2021transformer,
    title   = {Transformer in Transformer}, 
    author  = {Kai Han and An Xiao and Enhua Wu and Jianyuan Guo and Chunjing Xu and Yunhe Wang},
    year    = {2021},
    eprint  = {2103.00112},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
Comments
  • Only works if pixel_size**2 == patch_size?

    Hi, is this only supposed to work if pixel_size**2 == patch_size? When patch_size is set to any value that doesn't satisfy this equation, the following error occurs:

    --> 146         pixels += rearrange(self.pixel_pos_emb, 'n d -> () n d')
        147 
        148         for pixel_attn, pixel_ff, pixel_to_patch_residual, patch_attn, patch_ff in self.layers:
    
    RuntimeError: The size of tensor a (4) must match the size of tensor b (64) at non-singleton dimension 1
    

    The error occurred when running:

    import torch
    from transformer_in_transformer import TNT

    tnt = TNT(
        image_size = 128,       # size of image
        patch_dim = 256,        # dimension of patch token
        pixel_dim = 24,         # dimension of pixel token
        patch_size = 16,        # patch size
        pixel_size = 2,         # pixel size (2**2 != 16)
        depth = 6,              # depth
        heads = 1,              # number of attention heads
        num_classes = 2,        # output number of classes
        attn_dropout = 0.1,     # attention dropout
        ff_dropout = 0.1        # feedforward dropout
    )
    img = torch.randn(2, 3, 128, 128)
    logits = tnt(img)           # raises the RuntimeError shown above, as reported
    

    Since I am completely new to einops, it's quite hard for me to debug :D Thanks

    opened by PhilippMarquardt 1
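The two sizes in the error message of the report above (4 and 64) can be reproduced with plain arithmetic. The sketch below is not taken from tnt.py: the helper names are hypothetical, and it only assumes that each patch is tiled by square sub-patches of side pixel_size. It does not say which tensor in the library carries which size.

```python
def pixel_tokens_per_patch(patch_size: int, pixel_size: int) -> int:
    """Number of pixel_size x pixel_size tiles covering one patch_size x patch_size patch."""
    if patch_size % pixel_size != 0:
        raise ValueError("patch_size must be divisible by pixel_size")
    return (patch_size // pixel_size) ** 2


def sizes_agree(patch_size: int, pixel_size: int) -> bool:
    """True when pixel_size**2 equals the number of pixel tokens per patch."""
    return pixel_size ** 2 == pixel_tokens_per_patch(patch_size, pixel_size)


# README configuration (patch_size=16, pixel_size=4): both counts are 16.
readme = (pixel_tokens_per_patch(16, 4), 4 ** 2, sizes_agree(16, 4))
# Reported configuration (patch_size=16, pixel_size=2): 64 versus 4.
report = (pixel_tokens_per_patch(16, 2), 2 ** 2, sizes_agree(16, 2))

print(readme)  # (16, 16, True)
print(report)  # (64, 4, False)
```

The two counts coincide exactly when pixel_size**2 == patch_size, which is consistent with the condition asked about in the report.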
  • Not sure what is wrong!

    RuntimeError                              Traceback (most recent call last)
    in
         14
         15 img = torch.randn(1, 3, 256, 256)
    ---> 16 logits = tnt(img) # (2, 1000)

    ~/opt/anaconda3/envs/ml/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
       1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
       1109                 or _global_forward_hooks or _global_forward_pre_hooks):
    -> 1110             return forward_call(*input, **kwargs)
       1111         # Do not call functions when jit is used
       1112         full_backward_hooks, non_full_backward_hooks = [], []

    ~/opt/anaconda3/envs/ml/lib/python3.8/site-packages/transformer_in_transformer/tnt.py in forward(self, x)
        159         patches = repeat(self.patch_tokens[:(n + 1)], 'n d -> b n d', b = b)
        160
    --> 161         patches += rearrange(self.patch_pos_emb[:(n + 1)], 'n d -> () n d')
        162         pixels += rearrange(self.pixel_pos_emb, 'n d -> () n d')
        163

    RuntimeError: a view of a leaf Variable that requires grad is being used in an in-place operation.

    opened by RisabBiswas 0
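The message in the report above is the error PyTorch raises when an in-place operation such as += is applied to a view of a parameter. The sketch below reproduces that behaviour in isolation. It does not use TNT, the tensor names are stand-ins, and whether the same path is hit inside tnt.py may depend on the installed einops and PyTorch versions.

```python
import torch
from torch import nn

# Stand-ins for learned parameters; names and shapes are illustrative only.
tokens = nn.Parameter(torch.randn(5, 8))
pos_emb = nn.Parameter(torch.randn(5, 8))

# Slicing and unsqueezing a parameter yields a view of a leaf tensor
# that requires grad.
view = tokens[:5].unsqueeze(0)

try:
    view += pos_emb.unsqueeze(0)      # in-place add on the view
    in_place_failed = False
except RuntimeError as err:
    in_place_failed = True
    print(err)

# The out-of-place form allocates a new tensor and is accepted by autograd.
out = view + pos_emb.unsqueeze(0)
print(in_place_failed, tuple(out.shape))
```

Under that assumption, writing the addition out of place (x = x + y rather than x += y) is one way to avoid the error.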
  • patch_tokens vs patch_pos_emb

    Hi!

    I'm trying to understand your TNT implementation, and one thing that confused me is why there are two parameters, patch_tokens and patch_pos_emb, which seem to serve the same purpose: encoding patch position. Isn't one of them redundant?

    self.patch_tokens = nn.Parameter(torch.randn(num_patch_tokens + 1, patch_dim))
    self.patch_pos_emb = nn.Parameter(torch.randn(num_patch_tokens + 1, patch_dim))
    ...
    patches = repeat(self.patch_tokens[:(n + 1)], 'n d -> b n d', b = b)
    patches += rearrange(self.patch_pos_emb[:(n + 1)], 'n d -> () n d')
    
    opened by stas-sl 0
  • Inconsistent model params with MindSpore src code

    This codebase has no function or README description for the TNT-S/TNT-B models, something like:

    def tnt_b(num_class):
        return TNT(img_size=384,
                   patch_size=16,
                   num_channels=3,
                   embedding_dim=640,
                   num_heads=10,
                   num_layers=12,
                   hidden_dim=640*4,
                   stride=4,
                   num_class=num_class)
    

    Also, the number of heads in the inner block should be 4: https://github.com/lucidrains/transformer-in-transformer/blob/main/transformer_in_transformer/tnt.py#L135

    Has anyone reproduced the results reported in the paper with this codebase?

    opened by WongChen 0
  • Why does the loss become NaN?

    It is a great project, and I am very interested in the Transformer in Transformer model. I used your model to train on the Vehicle-1M dataset, a fine-grained visual classification dataset. With this model, the loss becomes NaN after some batch iterations. I decreased the learning rate of the Adam optimizer and clipped the gradients with torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0, norm_type=2), but the loss still sometimes becomes NaN. It seems that the gradients are not large, but they point in the same direction for many iterations. How can I solve this?

    opened by yt7589 3
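As a debugging aid for the NaN report above, a training step can check that the loss is finite before back-propagating, and apply the same gradient clipping the report mentions. This is a generic sketch, not a confirmed fix for TNT: training_step is a hypothetical helper, and the small nn.Linear model stands in for the real network so the example runs on its own.

```python
import torch
from torch import nn


def training_step(model, optimizer, loss_fn, inputs, targets, max_norm=2.0):
    """Run one update; return the loss value, or None if the loss is not finite."""
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    if not torch.isfinite(loss):
        # Skip the update so NaN/inf gradients never reach the weights.
        return None
    loss.backward()
    # Same clipping call as in the report.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm, norm_type=2)
    optimizer.step()
    return loss.item()


# Stand-in model and data (assumptions for illustration only).
model = nn.Linear(8, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()
targets = torch.randint(0, 2, (4,))

good = training_step(model, optimizer, loss_fn, torch.randn(4, 8), targets)
bad = training_step(model, optimizer, loss_fn, torch.full((4, 8), float("nan")), targets)
print(good is not None, bad is None)
```

Logging the batches for which the helper returns None can help tell bad input data apart from numerical instability inside the model.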
Owner
Phil Wang
Working with Attention. It's all we need.