An implementation of Transformer in Transformer (TNT) in TensorFlow for image classification, with attention inside local patches

Overview


An implementation of the Transformer in Transformer (TNT) paper by Han et al. for image classification. TNT pairs pixel-level attention inside each local patch with patch-level attention across patches, implemented here in TensorFlow.

PyTorch Implementation

Installation

Run the following to install:

pip install tnt-tensorflow

Developing tnt-tensorflow

To install tnt-tensorflow, along with tools you need to develop and test, run the following in your virtualenv:

git clone https://github.com/Rishit-dagli/Transformer-in-Transformer.git
# or clone your own fork

cd Transformer-in-Transformer
pip install -e .[dev]

Usage

import tensorflow as tf
from tnt import TNT

tnt = TNT(
    image_size=256,  # input image size (height and width)
    patch_dim=512,  # dimension of the patch token embeddings
    pixel_dim=24,  # dimension of the pixel token embeddings
    patch_size=16,  # size of each image patch
    pixel_size=4,  # size of each pixel-level sub-patch
    depth=5,  # number of TNT blocks
    num_classes=1000,  # number of output classes
    attn_dropout=0.1,  # attention dropout rate
    ff_dropout=0.1,  # feedforward dropout rate
)

img = tf.random.uniform(shape=[5, 3, 256, 256])  # batch of 5 channels-first images
logits = tnt(img)  # shape: (5, 1000)
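
Since TNT is a subclassed tf.keras.Model (see the v0.1.0 release notes), it should also plug into the standard Keras training API. A minimal sketch, assuming channels-first inputs and sparse integer labels; the data here is a random placeholder:

import tensorflow as tf
from tnt import TNT

tnt = TNT(image_size=256, patch_dim=512, pixel_dim=24, patch_size=16,
          pixel_size=4, depth=5, num_classes=1000,
          attn_dropout=0.1, ff_dropout=0.1)

# Random placeholder data standing in for a real dataset
images = tf.random.uniform(shape=[8, 3, 256, 256])
labels = tf.random.uniform(shape=[8], maxval=1000, dtype=tf.int32)

tnt.compile(
    optimizer=tf.keras.optimizers.Adam(1e-4),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
tnt.fit(images, labels, batch_size=4, epochs=1)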

Want to Contribute 🙋‍♂️ ?

Awesome! If you want to contribute to this project, you're always welcome! See the Contributing Guidelines. You can also take a look at the open issues to learn more about current and upcoming tasks.

Want to discuss? 💬

Have any questions, doubts, or opinions you'd like to share? You're always welcome; you can start a discussion.

Citation

@misc{han2021transformer,
      title={Transformer in Transformer}, 
      author={Kai Han and An Xiao and Enhua Wu and Jianyuan Guo and Chunjing Xu and Yunhe Wang},
      year={2021},
      eprint={2103.00112},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

License

Copyright 2020 Rishit Dagli

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Comments
  • Add Unit Tests

    Add Unit Tests

    The tests should check the rank and shape of the output tensors, and each test should subclass the tf.test.TestCase base class; a minimal sketch follows the links below.

    • [x] #15
    • [x] #16
    • [x] #18
    • [x] #17

    Feel free to take inspiration from:

    • https://github.com/Rishit-dagli/Fast-Transformer/blob/main/fast_transformer/test_fast_transformer.py
    • For parametrization, feel free to follow https://stackoverflow.com/a/34094/11878567; subTest can be used in the exact same way in TensorFlow
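
    A minimal sketch of such a test, assuming TNT is exported at the package root (the test name is illustrative):

    import tensorflow as tf
    from tnt import TNT

    class TNTTest(tf.test.TestCase):
        def test_logits_rank_and_shape(self):
            tnt = TNT(image_size=256, patch_dim=512, pixel_dim=24,
                      patch_size=16, pixel_size=4, depth=5,
                      num_classes=1000, attn_dropout=0.1, ff_dropout=0.1)
            img = tf.random.uniform(shape=[2, 3, 256, 256])
            logits = tnt(img)
            self.assertEqual(len(logits.shape), 2)  # rank check
            self.assertEqual(logits.shape, (2, 1000))  # shape check

    if __name__ == "__main__":
        tf.test.main()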
    enhancement good first issue 
    opened by Rishit-dagli 3
  • Update Workflows to run tests

    Update Workflows to run tests

    This issue follows #11

    Update GitHub Workflows to:

    • [ ] Run Tests before uploading to PyPI
    • [ ] Create a workflow to run tests on commits

    Feel free to take inspiration from https://github.com/Rishit-dagli/Fast-Transformer/tree/main/.github/workflows

    enhancement good first issue 
    opened by Rishit-dagli 0
  • Creates an Attention layer

    Creates an Attention layer

    Verify output shapes just from the attention layer:

    import tensorflow as tf
    from tnt import Attention  # assuming Attention is exported at the package root

    Attention(dim=256)(tf.random.normal([3, 256, 256]))

    # <tf.Tensor: shape=(3, 256, 256), dtype=float32, ...
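
    For reference, a minimal sketch of what such a layer could look like (single-head scaled dot-product self-attention; the repository's actual Attention layer makes some modifications, per the v0.1.0 release notes):

    import tensorflow as tf

    class Attention(tf.keras.layers.Layer):
        # Minimal sketch: single-head scaled dot-product self-attention.
        def __init__(self, dim):
            super().__init__()
            self.scale = dim ** -0.5
            self.to_qkv = tf.keras.layers.Dense(dim * 3, use_bias=False)
            self.to_out = tf.keras.layers.Dense(dim)

        def call(self, x):
            q, k, v = tf.split(self.to_qkv(x), 3, axis=-1)
            scores = tf.matmul(q, k, transpose_b=True) * self.scale
            attn = tf.nn.softmax(scores, axis=-1)
            return self.to_out(tf.matmul(attn, v))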
    

    Closes #3

    opened by Rishit-dagli 0
  • Put together a TNT class

    Put together a TNT class

    Verify shapes:

    import tensorflow as tf
    from tnt import TNT

    tnt = TNT(
        image_size=256,  # size of image
        patch_dim=512,  # dimension of patch token
        pixel_dim=24,  # dimension of pixel token
        patch_size=16,  # patch size
        pixel_size=4,  # pixel size
        depth=5,  # depth
        num_classes=1000,  # output number of classes
        attn_dropout=0.1,  # attention dropout
        ff_dropout=0.1,  # feedforward dropout
    )
    
    img = tf.random.uniform(shape=[1, 3, 256, 256])
    print(tnt(img).shape)
    
    # (1, 1000)
    opened by Rishit-dagli 0
  • Create an Attention layer

    Create an Attention layer

    Verify output shapes just from the attention layer:

    import tensorflow as tf
    from tnt import Attention  # assuming Attention is exported at the package root

    Attention(dim=256)(tf.random.normal([3, 256, 256]))

    # <tf.Tensor: shape=(3, 256, 256), dtype=float32, ...
    opened by Rishit-dagli 0
  • Create a PreNorm layer

    Create a PreNorm layer

    Verify output shapes from this layer:

    import tensorflow as tf
    from tnt import PreNorm  # assuming PreNorm is exported at the package root

    PreNorm(dim=1, fn=tf.keras.layers.Dense(5))(tf.random.normal([10, 1]))
    
    # <tf.Tensor: shape=(10, 5), dtype=float32, ...
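
    A minimal sketch of a typical pre-normalization wrapper (LayerNorm followed by the wrapped function); this is an assumption about the design, not necessarily the repository's exact code:

    import tensorflow as tf

    class PreNorm(tf.keras.layers.Layer):
        # Minimal sketch: normalize each example independently, then apply fn.
        # dim is kept for API parity; LayerNormalization infers the size from the input.
        def __init__(self, dim, fn):
            super().__init__()
            self.norm = tf.keras.layers.LayerNormalization(axis=-1)
            self.fn = fn

        def call(self, x, **kwargs):
            return self.fn(self.norm(x), **kwargs)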
    
    opened by Rishit-dagli 0
Releases (v0.2.0)
  • v0.2.0 (Feb 2, 2022)

    This is an interesting release for the project, including a pre-trained model on ImageNet, reproducibility of paper results, tests, and end-to-end training.

    ✅ Bug Fixes / Improvements

    • Create an end-to-end training example demonstrating how to train a TNT model for image classification with a custom training loop on the TF Flowers dataset (#14); a minimal sketch of such a loop appears below
    • A pre-trained model reproducing the paper's results has been made available (in this release as well as on TensorFlow Hub)
    • Create an off-the-shelf inference example that shows how to directly use the pre-trained model
    • Unit Tests for the Attention class (#19)
    • Unit Tests for the main TNT class (#20)

    Full Changelog: https://github.com/Rishit-dagli/Transformer-in-Transformer/compare/v0.1.0...v0.2.0

    Source code(tar.gz)
    Source code(zip)
    tnt_s_patch16_224.tar.gz(84.42 MB)
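
    For reference, a minimal sketch of a custom training loop in the spirit of the example above (random placeholder data stands in for TF Flowers; the hyperparameters are illustrative):

    import tensorflow as tf
    from tnt import TNT

    tnt = TNT(image_size=256, patch_dim=512, pixel_dim=24, patch_size=16,
              pixel_size=4, depth=5, num_classes=5,
              attn_dropout=0.1, ff_dropout=0.1)

    optimizer = tf.keras.optimizers.Adam(1e-4)
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

    # Random placeholder batch standing in for TF Flowers (5 classes)
    images = tf.random.uniform(shape=[4, 3, 256, 256])
    labels = tf.random.uniform(shape=[4], maxval=5, dtype=tf.int32)

    # One optimization step with a gradient tape
    with tf.GradientTape() as tape:
        logits = tnt(images, training=True)
        loss = loss_fn(labels, logits)
    grads = tape.gradient(loss, tnt.trainable_variables)
    optimizer.apply_gradients(zip(grads, tnt.trainable_variables))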
  • v0.1.0 (Dec 3, 2021)

    This is the initial release of TNT TensorFlow and implements Transformer in Transformer as a subclassed TensorFlow model.

    Classes

    • Attention: Implements attention as a TensorFlow Keras layer, with some modifications.
    • PreNorm: Normalizes the activations of the previous layer for each example in a batch independently, then applies a given function to the result; implemented as a TensorFlow Keras layer.
    • FeedForward: Creates a feed-forward network with two Dense layers and GELU activation; implemented as a TensorFlow Keras layer.
    • TNT: Implements the Transformer in Transformer model using the other classes and converts the output to logits; implemented as a TensorFlow Keras Model.
    Source code(tar.gz)
    Source code(zip)
    tnt_s_patch16_224.tar.gz(84.42 MB)
Owner
Rishit Dagli
High School, TEDx, 2x TED-Ed speaker | International Speaker | Microsoft Student Ambassador | Mentor, @TFUGMumbai | Organize @KotlinMumbai