Implementation of Perceiver, General Perception with Iterative Attention in TensorFlow

Overview

This Python package implements Perceiver: General Perception with Iterative Attention by Andrew Jaegle et al. in TensorFlow. The model builds on Transformers, but the input data enters only through a cross-attention mechanism (see figure), which allows it to scale to hundreds of thousands of inputs, like ConvNets. This also partly addresses the quadratic compute and memory bottleneck of Transformers.

Yannic Kilcher's video was very helpful.

Installation

Run the following to install:

pip install perceiver

Developing perceiver

To install perceiver along with the tools you need to develop and test it, run the following in your virtualenv:

git clone https://github.com/Rishit-dagli/Perceiver.git
# or clone your own fork

cd Perceiver
pip install -e .[dev]

A bit about Perceiver

The Perceiver model aims to deal with arbitrary configurations of different modalities using a single Transformer-based architecture. Transformers are flexible and make few assumptions about their inputs, but they also scale quadratically with the number of inputs in both memory and computation. This model proposes a mechanism that makes it possible to deal with high-dimensional inputs while retaining the expressivity and flexibility to deal with arbitrary input configurations.
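To make the quadratic scaling concrete, the sketch below counts the entries of a single all-to-all attention matrix (per head, per layer) when every pixel of a square image is treated as one input element. The function name and the image sizes are illustrative choices and are not part of this package.

```python
def self_attention_entries(num_inputs: int) -> int:
    """Entries in one all-to-all attention matrix (per head, per layer)."""
    return num_inputs * num_inputs


for side in (32, 64, 224):
    num_inputs = side * side  # treat every pixel as one input element
    print(side, num_inputs, self_attention_entries(num_inputs))
# 32 1024 1048576
# 64 4096 16777216
# 224 50176 2517630976
```

Doubling the image side quadruples the number of inputs and multiplies the attention matrix size by sixteen, which is why pixel-level self-attention quickly becomes impractical.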

The idea is to introduce a small set of latent units that forms an attention bottleneck through which the inputs must pass. This avoids the quadratic scaling of the all-to-all attention in a classical Transformer. The model can be seen as performing a fully end-to-end clustering of the inputs, with the latent units as the cluster centres, leveraging a highly asymmetric cross-attention layer. For spatial information, the authors compensate for the lack of explicit grid structure in the model by associating Fourier feature position encodings with the inputs.
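The asymmetry of the bottleneck can be illustrated with a minimal single-head cross-attention in NumPy. This is a framework-agnostic sketch under simplifying assumptions (random weights, no batching, no multi-head split, small sizes), not the code used in this package; all names here are hypothetical.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def cross_attention(latents, inputs, w_q, w_k, w_v):
    """Single-head cross-attention: queries from latents, keys/values from inputs."""
    q = latents @ w_q                          # (num_latents, d)
    k = inputs @ w_k                           # (num_inputs, d)
    v = inputs @ w_v                           # (num_inputs, d)
    scores = q @ k.T / np.sqrt(q.shape[-1])    # (num_latents, num_inputs)
    weights = softmax(scores, axis=-1)         # each latent distributes attention over inputs
    return weights @ v, weights


rng = np.random.default_rng(0)
num_inputs, input_dim = 32 * 32, 3     # a small image flattened to 1024 pixels
num_latents, latent_dim, d = 16, 8, 8  # illustrative sizes

inputs = rng.normal(size=(num_inputs, input_dim))
latents = rng.normal(size=(num_latents, latent_dim))  # learned parameters in a real model
w_q = rng.normal(size=(latent_dim, d))
w_k = rng.normal(size=(input_dim, d))
w_v = rng.normal(size=(input_dim, d))

out, weights = cross_attention(latents, inputs, w_q, w_k, w_v)
print(out.shape)      # (16, 8)
print(weights.shape)  # (16, 1024)
```

The attention matrix has shape (num_latents, num_inputs) rather than (num_inputs, num_inputs), so its cost grows linearly with the number of inputs. Any self-attention that follows operates only on the much smaller set of latents.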

Usage

from perceiver import Perceiver
import tensorflow as tf

model = Perceiver(
    input_channels = 3,          # number of channels for each token of the input
    input_axis = 2,              # number of axes of the input data (2 for images, 3 for video)
    num_freq_bands = 6,          # number of frequency bands K (the paper's position encoding has size 2 * K + 1 per axis)
    max_freq = 10.,              # maximum frequency, a hyperparameter depending on how fine the data is
    depth = 6,                   # depth of the network
    num_latents = 256,           # number of latents
    latent_dim = 512,            # latent dimension
    cross_heads = 1,             # number of heads for cross-attention (the paper uses 1)
    latent_heads = 8,            # number of heads for latent self-attention
    cross_dim_head = 64,         # dimension of each cross-attention head
    latent_dim_head = 64,        # dimension of each latent self-attention head
    num_classes = 1000,          # number of output classes
    attn_dropout = 0.,           # dropout rate for attention
    ff_dropout = 0.,             # dropout rate for feed-forward layers
)

img = tf.random.normal([1, 224, 224, 3]) # one random ImageNet-sized image
model(img) # (1, 1000)
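The `num_freq_bands` and `max_freq` arguments control the Fourier feature position encoding. The NumPy sketch below follows the scheme described in the paper (sine and cosine at K frequencies spaced linearly between 1 and `max_freq / 2`, plus the raw position, for positions scaled to [-1, 1]). Whether this package computes the encoding in exactly this way is an assumption; `fourier_encode` is a hypothetical helper written for illustration.

```python
import numpy as np


def fourier_encode(positions, max_freq, num_bands):
    """Encode positions in [-1, 1] as sin/cos at num_bands frequencies plus the raw value."""
    positions = positions[..., None]                     # (..., 1)
    freqs = np.linspace(1.0, max_freq / 2.0, num_bands)  # (num_bands,)
    angles = positions * freqs * np.pi                   # (..., num_bands)
    return np.concatenate([np.sin(angles), np.cos(angles), positions], axis=-1)


height, width, num_bands, max_freq = 224, 224, 6, 10.0
ys = np.linspace(-1.0, 1.0, height)
xs = np.linspace(-1.0, 1.0, width)
grid = np.stack(np.meshgrid(ys, xs, indexing="ij"), axis=-1)  # (224, 224, 2)

encoded = fourier_encode(grid, max_freq, num_bands)  # (224, 224, 2, 13)
encoded = encoded.reshape(height, width, -1)         # (224, 224, 26)
print(encoded.shape)  # (224, 224, 26)
```

With two axes and K = 6, each pixel receives 2 * (2 * 6 + 1) = 26 position features. Under this scheme, concatenating them with the 3 image channels would give 29 features per input element.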

About the notebooks

perceiver_example

This notebook installs the perceiver package and shows an example of running it on a single ImageNet image ([1, 224, 224, 3]) with 1000 classes to demonstrate how the model works.

Want to Contribute 🙋‍♂️ ?

Awesome! If you want to contribute to this project, you're always welcome! See Contributing Guidelines. You can also take a look at open issues for getting more information about current or upcoming tasks.

Want to discuss? 💬

Have any questions or doubts, or want to share your opinions and views? You're always welcome to start a discussion.

Citations

@misc{jaegle2021perceiver,
    title   = {Perceiver: General Perception with Iterative Attention},
    author  = {Andrew Jaegle and Felix Gimeno and Andrew Brock and Andrew Zisserman and Oriol Vinyals and Joao Carreira},
    year    = {2021},
    eprint  = {2103.03206},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
Comments
  • error with tf2.4.1

    Hello Rishit,

    thank you for your Perceiver implementation! I have two notes, though I am not very familiar with tf2. You define and call a tf.keras.Sequential model here: https://github.com/Rishit-dagli/Perceiver/blob/4d3b9b0514da4fb623d178e3e70df1836ebad5ba/perceiver/perceiver.py#L106 For my version of tf at least, this throws an error; I think it should be defined once in __init__ and then just called in call.

    And just above it, you compute data but then you don't pass it to self.model. Is that correct?

    bug 
    opened by abred 3
  • Training code

    Hi there,

    I've tried to set up a standard MNIST training over the last few days using the Perceiver code provided here. So far, I've not been able to come up with any solution where the model actually learns anything. A major problem so far has been the way the model is written with no support for model.fit() and the whole functional API.

    Do you happen to have any training example code for your model which you could provide here in this repo? MNIST as the default starting point would be nice, but anything would do the job as well :)

    question 
    opened by tpetri94 2
  • Create a FeedForward layer

    Create a simple FeedForward layer as a tf.keras.layers.Layer, which should essentially contain a Dense layer with the modified GELU activation (#2). Optionally, I could also include a dropout layer and another Dense layer which should have the number of neurons equal to the dimension

    opened by Rishit-dagli 0
  • Implement a PreNorm layer

    Create a Normalization layer from tf.keras.layers.Layer. This should essentially figure out the right axis and apply layer normalization to it.

    opened by Rishit-dagli 0
  • Don't pin TensorFlow version to a specific number

    Hello,

    In setup.py you should change "tensorflow~=2.4.0" to "tensorflow>2.4.0" to ensure any version above the minimal one is used.

    bug 
    opened by ebursztein 0
Releases(v0.1.2)
Owner
Rishit Dagli
High School,TEDx,2xTED-Ed speaker | International Speaker | Microsoft Student Ambassador | Mentor, @TFUGMumbai | Organize @KotlinMumbai