Explainability for Vision Transformers (in PyTorch)

Overview

Explainability for Vision Transformers (in PyTorch)

This repository implements methods for explainability in Vision Transformers.

See also https://jacobgil.github.io/deeplearning/vision-transformer-explainability

Currently implemented:

  • Attention Rollout.

  • Gradient Attention Rollout for class specific explainability. This is our attempt to further build upon and improve Attention Rollout.

  • TBD Attention flow is work in progress.

Includes some tweaks and tricks to get it working:

  • Different Attention Head fusion methods,
  • Removing the lowest attentions.

Usage

  • From code
from vit_grad_rollout import VITAttentionGradRollout

model = torch.hub.load('facebookresearch/deit:main', 
'deit_tiny_patch16_224', pretrained=True)
grad_rollout = VITAttentionGradRollout(model, discard_ratio=0.9, head_fusion='max')
mask = grad_rollout(input_tensor, category_index=243)
  • From the command line:
python vit_explain.py --image_path  --head_fusion  --discard_ratio  --category_index 

If category_index isn't specified, Attention Rollout will be used, otherwise Gradient Attention Rollout will be used.

Notice that by default, this uses the 'Tiny' model from Training data-efficient image transformers & distillation through attention hosted on torch hub.

Where did the Transformer pay attention to in this image?

Image Vanilla Attention Rollout With discard_ratio+max fusion

Gradient Attention Rollout for class specific explainability

The Attention that flows in the transformer passes along information belonging to different classes. Gradient roll out lets us see what locations the network paid attention too, but it tells us nothing about if it ended up using those locations for the final classification.

We can multiply the attention with the gradient of the target class output, and take the average among the attention heads (while masking out negative attentions) to keep only attention that contributes to the target category (or categories).

Where does the Transformer see a Dog (category 243), and a Cat (category 282)?

Where does the Transformer see a Musket dog (category 161) and a Parrot (category 87):

Tricks and Tweaks to get this working

Filtering the lowest attentions in every layer

--discard_ratio

Removes noise by keeping the strongest attentions.

Results for dIfferent values:

Different Attention Head Fusions

The Attention Rollout method suggests taking the average attention accross the attention heads,

but emperically it looks like taking the Minimum value, Or the Maximum value combined with --discard_ratio, works better.

--head_fusion

Image Mean Fusion Min Fusion

References

Requirements

pip install timm

Owner
Jacob Gildenblat
Machine learning / Computer Vision developer.
Jacob Gildenblat
This repository contains the reference implementation for our proposed Convolutional CRFs.

ConvCRF This repository contains the reference implementation for our proposed Convolutional CRFs in PyTorch (Tensorflow planned). The two main entry-

Marvin Teichmann 553 Dec 07, 2022
使用yolov5训练自己数据集(详细过程)并通过flask部署

使用yolov5训练自己的数据集(详细过程)并通过flask部署 依赖库 torch torchvision numpy opencv-python lxml tqdm flask pillow tensorboard matplotlib pycocotools Windows,请使用 pycoc

HB.com 19 Dec 28, 2022
Recommendation algorithms for large graphs

Fast recommendation algorithms for large graphs based on link analysis. License: Apache Software License Author: Emmanouil (Manios) Krasanakis Depende

Multimedia Knowledge and Social Analytics Lab 27 Jan 07, 2023
Python and C++ implementation of "MarkerPose: Robust real-time planar target tracking for accurate stereo pose estimation". Accepted at LXCV @ CVPR 2021.

MarkerPose: Robust real-time planar target tracking for accurate stereo pose estimation This is a PyTorch and LibTorch implementation of MarkerPose: a

Jhacson Meza 47 Nov 18, 2022
Code for the paper "SmoothMix: Training Confidence-calibrated Smoothed Classifiers for Certified Robustness" (NeurIPS 2021)

SmoothMix: Training Confidence-calibrated Smoothed Classifiers for Certified Robustness (NeurIPS2021) This repository contains code for the paper "Smo

Jongheon Jeong 17 Dec 27, 2022
Calibrate your listeners! Robust communication-based training for pragmatic speakers. Findings of EMNLP 2021.

Calibrate your listeners! Robust communication-based training for pragmatic speakers Rose E. Wang, Julia White, Jesse Mu, Noah D. Goodman Findings of

Rose E. Wang 3 Apr 02, 2022
PyTorch implementation of adversarial patch

adversarial-patch PyTorch implementation of adversarial patch This is an implementation of the Adversarial Patch paper. Not official and likely to hav

Jamie Hayes 172 Nov 29, 2022
Code for the AAAI 2022 paper "Zero-Shot Cross-Lingual Machine Reading Comprehension via Inter-Sentence Dependency Graph".

multilingual-mrc-isdg Code for the AAAI 2022 paper "Zero-Shot Cross-Lingual Machine Reading Comprehension via Inter-Sentence Dependency Graph". This r

Liyan 5 Dec 07, 2022
Meta-learning for NLP

Self-Supervised Meta-Learning for Few-Shot Natural Language Classification Tasks Code for training the meta-learning models and fine-tuning on downstr

IESL 43 Nov 08, 2022
YOLOv4-v3 Training Automation API for Linux

This repository allows you to get started with training a state-of-the-art Deep Learning model with little to no configuration needed! You provide your labeled dataset or label your dataset using our

BMW TechOffice MUNICH 626 Dec 31, 2022
Cobalt Strike teamserver detection.

Cobalt-Strike-det Cobalt Strike teamserver detection. usage: cobaltstrike_verify.py [-l TARGETS] [-t THREADS] optional arguments: -h, --help show this

TimWhite 17 Sep 27, 2022
An open source AutoML toolkit for automate machine learning lifecycle, including feature engineering, neural architecture search, model compression and hyper-parameter tuning.

NNI Doc | 简体中文 NNI (Neural Network Intelligence) is a lightweight but powerful toolkit to help users automate Feature Engineering, Neural Architecture

Microsoft 12.4k Dec 31, 2022
A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.

TorchRL Disclaimer This library is not officially released yet and is subject to change. The features are available before an official release so that

Meta Research 860 Jan 07, 2023
LoveDA: A Remote Sensing Land-Cover Dataset for Domain Adaptive Semantic Segmentation (NeurIPS2021 Benchmark and Dataset Track)

LoveDA: A Remote Sensing Land-Cover Dataset for Domain Adaptive Semantic Segmentation by Junjue Wang, Zhuo Zheng, Ailong Ma, Xiaoyan Lu, and Yanfei Zh

Kingdrone 174 Dec 22, 2022
A Large-Scale Dataset for Spinal Vertebrae Segmentation in Computed Tomography

A Large-Scale Dataset for Spinal Vertebrae Segmentation in Computed Tomography

ICT.MIRACLE lab 75 Dec 26, 2022
A general framework for deep learning experiments under PyTorch based on pytorch-lightning

torchx Torchx is a general framework for deep learning experiments under PyTorch based on pytorch-lightning. TODO list gan-like training wrapper text

Yingtian Liu 6 Mar 17, 2022
Code for the paper "Next Generation Reservoir Computing"

Next Generation Reservoir Computing This is the code for the results and figures in our paper "Next Generation Reservoir Computing". They are written

OSU QuantInfo Lab 105 Dec 20, 2022
[NeurIPS 2021] "Delayed Propagation Transformer: A Universal Computation Engine towards Practical Control in Cyber-Physical Systems"

Delayed Propagation Transformer: A Universal Computation Engine towards Practical Control in Cyber-Physical Systems Introduction Multi-agent control i

VITA 6 May 05, 2022
Robust Self-augmentation for NER with Meta-reweighting

Robust Self-augmentation for NER with Meta-reweighting

Lam chi 17 Nov 22, 2022
Unsupervised Feature Loss (UFLoss) for High Fidelity Deep learning (DL)-based reconstruction

Unsupervised Feature Loss (UFLoss) for High Fidelity Deep learning (DL)-based reconstruction Official github repository for the paper High Fidelity De

28 Dec 16, 2022