Implementation of the state-of-the-art vision transformers with tensorflow

Overview

ViT Tensorflow

This repository contains the tensorflow implementation of the state-of-the-art vision transformers (a category of computer vision models first introduced in An Image is worth 16 x 16 words). This repository is inspired from the work of lucidrains which is vit-pytorch. I hope you enjoy these implementations :)

Models

Requirements

pip install tensorflow

Vision Transformer

Vision transformer was introduced in An Image is worth 16 x 16 words. This model uses a Transformer encoder to classify images with pure attention and no convolution.

Usage

Defining the Model

from vit import ViT
import tensorflow as tf

vitClassifier = ViT(
                    num_classes=1000,
                    patch_size=16,
                    num_of_patches=(224//16)**2,
                    d_model=128,
                    heads=2,
                    num_layers=4,
                    mlp_rate=2,
                    dropout_rate=0.1,
                    prediction_dropout=0.3,
)
Params
  • num_classes: int
    number of classes used for the final classification head
  • patch_size: int
    patch_size used for the tokenization
  • num_of_patches: int
    number of patches after the tokenization which is used for the positional encoding, Generally it can be computed by the following formula (((h-patch_size)//patch_size) + 1)*(((w-patch_size)//patch_size) + 1) where h is the height of the image and w is the width of the image. In addition, when height and width of the image are devisable by the patch_size the following formula can be used as well (h//patch_size)*(w//patch_size)
  • d_model: int
    hidden dimension of the transformer encoder and the demnesion used for patch embedding
  • heads: int
    number of heads used for the multi-head attention mechanism
  • num_layers: int
    number of blocks in encoder transformer
  • mlp_rate: int
    the rate of expansion in the feed-forward block of each transformer block (the dimension after expansion is mlp_rate * d_model)
  • dropout_rate: float
    dropout rate used in the multi-head attention mechanism
  • prediction_dropout: float
    dropout rate used in the final prediction head of the model

Inference

sampleInput = tf.random.normal(shape=(1 , 224 , 224 , 3))
output = vitClassifier(sampleInput , training=False)
print(output.shape) # (1 , 1000)

Training

vitClassifier.compile(
              loss=tf.keras.losses.SparseCategoricalCrossentropy(),
              optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
              metrics=[
                       tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
                       tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5 , name="top_5_accuracy"),
              ])

vitClassifier.fit(
              trainingData, #Tensorflow dataset of images and labels in shape of ((b , h , w , 3) , (b,))
              validation_data=valData, #The same as training
              epochs=100,)

Convolutional Vision Transformer

Convolutional Vision Transformer was introduced in here. This model uses a hierarchical (multi-stage) architecture with convolutional embeddings in the begining of each stage. it also uses Convolutional Transformer Blocks to improve the orginal vision transformer by adding CNNs inductive bias into the architecture.

Usage

Defining the Model

from cvt import CvT , CvTStage
import tensorflow as tf

cvtModel = CvT(
num_of_classes=1000, 
stages=[
        CvTStage(projectionDim=64, 
                 heads=1, 
                 embeddingWindowSize=(7 , 7), 
                 embeddingStrides=(4 , 4), 
                 layers=1,
                 projectionWindowSize=(3 , 3), 
                 projectionStrides=(2 , 2), 
                 ffnRate=4,
                 dropoutRate=0.1),
        CvTStage(projectionDim=192,
                 heads=3,
                 embeddingWindowSize=(3 , 3), 
                 embeddingStrides=(2 , 2),
                 layers=1, 
                 projectionWindowSize=(3 , 3), 
                 projectionStrides=(2 , 2), 
                 ffnRate=4,
                 dropoutRate=0.1),
        CvTStage(projectionDim=384,
                 heads=6,
                 embeddingWindowSize=(3 , 3),
                 embeddingStrides=(2 , 2),
                 layers=1,
                 projectionWindowSize=(3 , 3),
                 projectionStrides=(2 , 2), 
                 ffnRate=4,
                 dropoutRate=0.1)
],
dropout=0.5)
CvT Params
  • num_of_classes: int
    number of classes used in the final prediction layer
  • stages: list of CvTStage
    list of cvt stages
  • dropout: float
    dropout rate used for the prediction head
CvTStage Params
  • projectionDim: int
    dimension used for the multi-head attention mechanism and the convolutional embedding
  • heads: int
    number of heads in the multi-head attention mechanism
  • embeddingWindowSize: tuple(int , int)
    window size used for the convolutional emebdding
  • embeddingStrides: tuple(int , int)
    strides used for the convolutional embedding
  • layers: int
    number of convolutional transformer blocks
  • projectionWindowSize: tuple(int , int)
    window size used for the convolutional projection in each convolutional transformer block
  • projectionStrides: tuple(int , int)
    strides used for the convolutional projection in each convolutional transformer block
  • ffnRate: int
    expansion rate of the mlp block in each convolutional transformer block
  • dropoutRate: float
    dropout rate used in each convolutional transformer block

Inference

sampleInput = tf.random.normal(shape=(1 , 224 , 224 , 3))
output = cvtModel(sampleInput , training=False)
print(output.shape) # (1 , 1000)

Training

cvtModel.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        metrics=[
                 tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
                 tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5 , name="top_5_accuracy"),
        ])

cvtModel.fit(
        trainingData, #Tensorflow dataset of images and labels in shape of ((b , h , w , 3) , (b,))
        validation_data=valData, #The same as training
        epochs=100,)

Pyramid Vision Transformer V1

Pyramid Vision Transformer V1 was introduced in here. This model stacks multiple Transformer Encoders to form the first convolution-free multi-scale backbone for various visual tasks including Image Segmentation , Object Detection and etc. In addition to this a new attention mechanism called Spatial Reduction Attention (SRA) is also introduced in this paper to reduce the quadratic complexity of the multi-head attention mechansim.

Usage

Defining the Model

from pvt_v1 import PVT , PVTStage
import tensorflow as tf

pvtModel = PVT(
num_of_classes=1000, 
stages=[
        PVTStage(d_model=64,
                 patch_size=(2 , 2),
                 heads=1,
                 reductionFactor=2,
                 mlp_rate=2,
                 layers=2, 
                 dropout_rate=0.1),
        PVTStage(d_model=128,
                 patch_size=(2 , 2),
                 heads=2, 
                 reductionFactor=2, 
                 mlp_rate=2, 
                 layers=2, 
                 dropout_rate=0.1),
        PVTStage(d_model=320,
                 patch_size=(2 , 2),
                 heads=5, 
                 reductionFactor=2, 
                 mlp_rate=2, 
                 layers=2, 
                 dropout_rate=0.1),
],
dropout=0.5)
PVT Params
  • num_of_classes: int
    number of classes used in the final prediction layer
  • stages: list of PVTStage
    list of pvt stages
  • dropout: float
    dropout rate used for the prediction head
PVTStage Params
  • d_model: int
    dimension used for the SRA mechanism and the patch embedding
  • patch_size: tuple(int , int)
    window size used for the patch emebdding
  • heads: int
    number of heads in the SRA mechanism
  • reductionFactor: int
    reduction factor used for the down sampling of the K and V in the SRA mechanism
  • mlp_rate: int
    expansion rate used in the feed-forward block
  • layers: int
    number of transformer encoders
  • dropout_rate: float
    dropout rate used in each transformer encoder

Inference

sampleInput = tf.random.normal(shape=(1 , 224 , 224 , 3))
output = pvtModel(sampleInput , training=False)
print(output.shape) # (1 , 1000)

Training

pvtModel.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        metrics=[
                 tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
                 tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5 , name="top_5_accuracy"),
        ])

pvtModel.fit(
        trainingData, #Tensorflow dataset of images and labels in shape of ((b , h , w , 3) , (b,))
        validation_data=valData, #The same as training
        epochs=100,)

Pyramid Vision Transformer V2

Pyramid Vision Transformer V2 was introduced in here. This model is an improved version of the PVT V1. The improvements of this version are as follows:

  1. It uses overlapping patch embedding by using padded convolutions
  2. It uses convolutional feed-forward blocks which have a depth-wise convolution after the first fully-connected layer
  3. It uses a fixed pooling instead of convolutions for down sampling the K and V in the SRA attention mechanism (The new attention mechanism is called Linear SRA)

Usage

Defining the Model

from pvt_v2 import PVTV2 , PVTV2Stage
import tensorflow as tf

pvtV2Model = PVTV2(
num_of_classes=1000, 
stages=[
        PVTV2Stage(d_model=64,
                   windowSize=(2 , 2), 
                   heads=1,
                   poolingSize=(7 , 7), 
                   mlp_rate=2, 
                   mlp_windowSize=(3 , 3), 
                   layers=2, 
                   dropout_rate=0.1),
        PVTV2Stage(d_model=128, 
                   windowSize=(2 , 2),
                   heads=2,
                   poolingSize=(7 , 7), 
                   mlp_rate=2, 
                   mlp_windowSize=(3 , 3), 
                   layers=2,
                   dropout_rate=0.1),
        PVTV2Stage(d_model=320,
                   windowSize=(2 , 2), 
                   heads=5, 
                   poolingSize=(7 , 7), 
                   mlp_rate=2, 
                   mlp_windowSize=(3 , 3), 
                   layers=2, 
                   dropout_rate=0.1),
],
dropout=0.5)
PVT Params
  • num_of_classes: int
    number of classes used in the final prediction layer
  • stages: list of PVTV2Stage
    list of pvt v2 stages
  • dropout: float
    dropout rate used for the prediction head
PVTStage Params
  • d_model: int
    dimension used for the Linear SRA mechanism and the convolutional patch embedding
  • windowSize: tuple(int , int)
    window size used for the convolutional patch emebdding
  • heads: int
    number of heads in the Linear SRA mechanism
  • poolingSize: tuple(int , int)
    size of the K and V after the fixed pooling
  • mlp_rate: int
    expansion rate used in the convolutional feed-forward block
  • mlp_windowSize: tuple(int , int)
    the window size used for the depth-wise convolution in the convolutional feed-forward block
  • layers: int
    number of transformer encoders
  • dropout_rate: float
    dropout rate used in each transformer encoder

Inference

sampleInput = tf.random.normal(shape=(1 , 224 , 224 , 3))
output = pvtV2Model(sampleInput , training=False)
print(output.shape) # (1 , 1000)

Training

pvtV2Model.compile(
          loss=tf.keras.losses.SparseCategoricalCrossentropy(),
          optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
          metrics=[
                   tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
                   tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5 , name="top_5_accuracy"),
          ])

pvtV2Model.fit(
          trainingData, #Tensorflow dataset of images and labels in shape of ((b , h , w , 3) , (b,))
          validation_data=valData, #The same as training
          epochs=100,)

DeiT

DeiT was introduced in Training Data-Efficient Image Transformers & Distillation Through Attention. Since original vision transformer is data hungry due to the lack of existance of any inductive bias (unlike CNNs) a lot of data is required to train original vision transformer in order to surpass the state-of-the-art CNNs such as Resnet. Therefore, in this paper authors used a pre-trained CNN such as resent during training and used a sepcial loss function to perform distillation through attention.

Usage

Defining the Model

from deit import DeiT
import tensorflow as tf

teacherModel = tf.keras.applications.ResNet50(include_top=True, 
                                              weights="imagenet", 
                                              input_shape=(224 , 224 , 3))

deitModel = DeiT(
                 num_classes=1000,
                 patch_size=16,
                 num_of_patches=(224//16)**2,
                 d_model=128,
                 heads=2,
                 num_layers=4,
                 mlp_rate=2,
                 teacherModel=teacherModel,
                 temperature=1.0, 
                 alpha=0.5,
                 hard=False, 
                 dropout_rate=0.1,
                 prediction_dropout=0.3,
)
Params
  • num_classes: int
    number of classes used for the final classification head
  • patch_size: int
    patch_size used for the tokenization
  • num_of_patches: int
    number of patches after the tokenization which is used for the positional encoding, Generally it can be computed by the following formula (((h-patch_size)//patch_size) + 1)*(((w-patch_size)//patch_size) + 1) where h is the height of the image and w is the width of the image. In addition, when height and width of the image are devisable by the patch_size the following formula can be used as well (h//patch_size)*(w//patch_size)
  • d_model: int
    hidden dimension of the transformer encoder and the demnesion used for patch embedding
  • heads: int
    number of heads used for the multi-head attention mechanism
  • num_layers: int
    number of blocks in encoder transformer
  • mlp_rate: int
    the rate of expansion in the feed-forward block of each transformer block (the dimension after expansion is mlp_rate * d_model)
  • teacherModel: Tensorflow Model
    the teacherModel used for the distillation during training, This model is a pre-trained CNN model with the same input_shape and output_shape as the Transformer
  • temperature: float
    the temperature parameter in the loss
  • alpha: float
    the coefficient balancing the Kullback–Leibler divergence loss (KL) and the cross-entropy loss
  • hard: bool
    indicates using Hard-label distillation or Soft distillation
  • dropout_rate: float
    dropout rate used in the multi-head attention mechanism
  • prediction_dropout: float
    dropout rate used in the final prediction head of the model

Inference

sampleInput = tf.random.normal(shape=(1 , 224 , 224 , 3))
output = deitModel(sampleInput , training=False)
print(output.shape) # (1 , 1000)

Training

#Note that the loss is defined inside the model and no loss should be passed here
deitModel.compile(
         optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
         metrics=[
                  tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
                  tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5 , name="top_5_accuracy"),
         ])

deitModel.fit(
         trainingData, #Tensorflow dataset of images and labels in shape of ((b , h , w , 3) , (b , num_classes))
         validation_data=valData, #The same as training
         epochs=100,)
Owner
Mohammadmahdi NouriBorji
Mohammadmahdi NouriBorji
Official implementation of the method ContIG, for self-supervised learning from medical imaging with genomics

ContIG: Self-supervised Multimodal Contrastive Learning for Medical Imaging with Genetics This is the code implementation of the paper "ContIG: Self-s

Digital Health & Machine Learning 22 Dec 13, 2022
chainladder - Property and Casualty Loss Reserving in Python

chainladder (python) chainladder - Property and Casualty Loss Reserving in Python This package gets inspiration from the popular R ChainLadder package

Casualty Actuarial Society 130 Dec 07, 2022
CCNet: Criss-Cross Attention for Semantic Segmentation (TPAMI 2020 & ICCV 2019).

CCNet: Criss-Cross Attention for Semantic Segmentation Paper Links: Our most recent TPAMI version with improvements and extensions (Earlier ICCV versi

Zilong Huang 1.3k Dec 27, 2022
DeepCAD: A Deep Generative Network for Computer-Aided Design Models

DeepCAD This repository provides source code for our paper: DeepCAD: A Deep Generative Network for Computer-Aided Design Models Rundi Wu, Chang Xiao,

Rundi Wu 85 Dec 31, 2022
Deep Implicit Moving Least-Squares Functions for 3D Reconstruction

DeepMLS: Deep Implicit Moving Least-Squares Functions for 3D Reconstruction This repository contains the implementation of the paper: Deep Implicit Mo

103 Dec 22, 2022
Real-time face detection and emotion/gender classification using fer2013/imdb datasets with a keras CNN model and openCV.

Real-time face detection and emotion/gender classification using fer2013/imdb datasets with a keras CNN model and openCV.

Octavio Arriaga 5.3k Dec 30, 2022
Open-source code for Generic Grouping Network (GGN, CVPR 2022)

Open-World Instance Segmentation: Exploiting Pseudo Ground Truth From Learned Pairwise Affinity Pytorch implementation for "Open-World Instance Segmen

Meta Research 99 Dec 06, 2022
Code for paper "Multi-level Disentanglement Graph Neural Network"

Multi-level Disentanglement Graph Neural Network (MD-GNN) This is a PyTorch implementation of the MD-GNN, and the code includes the following modules:

Lirong Wu 6 Dec 29, 2022
Official code repository for A Simple Long-Tailed Rocognition Baseline via Vision-Language Model.

This is the official code repository for A Simple Long-Tailed Rocognition Baseline via Vision-Language Model.

peng gao 42 Nov 26, 2022
Certifiable Outlier-Robust Geometric Perception

Certifiable Outlier-Robust Geometric Perception About This repository holds the implementation for certifiably solving outlier-robust geometric percep

83 Dec 31, 2022
The BCNet related data and inference model.

BCNet This repository includes the some source code and related dataset of paper BCNet: Learning Body and Cloth Shape from A Single Image, ECCV 2020,

81 Dec 12, 2022
A Graph Neural Network Tool for Recovering Dense Sub-graphs in Random Dense Graphs.

PYGON A Graph Neural Network Tool for Recovering Dense Sub-graphs in Random Dense Graphs. Installation This code requires to install and run the graph

Yoram Louzoun's Lab 0 Jun 25, 2021
Model that predicts the probability of a Twitter user being anti-vaccination.

stylebody {text-align: justify}/style AVAXTAR: Anti-VAXx Tweet AnalyzeR AVAXTAR is a python package to identify anti-vaccine users on twitter. The

10 Sep 27, 2022
Improving Calibration for Long-Tailed Recognition (CVPR2021)

MiSLAS Improving Calibration for Long-Tailed Recognition Authors: Zhisheng Zhong, Jiequan Cui, Shu Liu, Jiaya Jia [arXiv] [slide] [BibTeX] Introductio

DV Lab 116 Dec 20, 2022
Real-time LIDAR-based Urban Road and Sidewalk detection for Autonomous Vehicles šŸš—

urban_road_filter: a real-time LIDAR-based urban road and sidewalk detection algorithm for autonomous vehicles Dependency ROS (tested with Kinetic and

JKK - Vehicle Industry Research Center 180 Dec 12, 2022
Code for "Adversarial Attack Generation Empowered by Min-Max Optimization", NeurIPS 2021

Min-Max Adversarial Attacks [Paper] [arXiv] [Video] [Slide] Adversarial Attack Generation Empowered by Min-Max Optimization Jingkang Wang, Tianyun Zha

Jingkang Wang 12 Nov 23, 2022
RMTD: Robust Moving Target Defence Against False Data Injection Attacks in Power Grids

RMTD: Robust Moving Target Defence Against False Data Injection Attacks in Power Grids Real-time detection performance. This repo contains the code an

0 Nov 10, 2021
Jingju baseline - A baseline model of our project of Beijing opera script generation

Jingju Baseline It is a baseline of our project about Beijing opera script gener

midon 1 Jan 14, 2022
ā€œRobust Lightweight Facial Expression Recognition Network with Label Distribution Trainingā€, AAAI 2021.

EfficientFace Zengqun Zhao, Qingshan Liu, Feng Zhou. "Robust Lightweight Facial Expression Recognition Network with Label Distribution Training". AAAI

Zengqun Zhao 119 Jan 08, 2023
Code release for SLIP Self-supervision meets Language-Image Pre-training

SLIP: Self-supervision meets Language-Image Pre-training What you can find in this repo: Pre-trained models (with ViT-Small, Base, Large) and code to

Meta Research 621 Dec 31, 2022