My implementation of DeepMind's Perceiver

DeepMind Perceiver (in PyTorch)

Disclaimer: This is not official and I'm not affiliated with DeepMind.

This is my implementation of the model from the paper "Perceiver: General Perception with Iterative Attention". You can read more about the model on DeepMind's website.

I trained an MNIST model, which you can find in models/mnist.pkl or load with perceiver.load_mnist_model(). It reaches 96.02% accuracy on the MNIST test set.

Getting started

To run this you need PyTorch installed:

pip3 install torch

From perceiver you can import Perceiver or PerceiverLogits.

Then use it as follows (or see examples.ipynb):

from perceiver import Perceiver

model = Perceiver(
    input_channels, # <- How many channels in the input? E.g. 3 for RGB.
    input_shape, # <- How big is the input in the different dimensions? E.g. (28, 28) for MNIST
    fourier_bands=4, # <- How many bands should the positional encoding have?
    latents=64, # <- How many latent vectors?
    d_model=32, # <- Model dimensionality. Every pixel/token/latent vector will have this size.
    heads=8, # <- How many heads in self-attention? Cross-attention always has 1 head.
    latent_blocks=6, # <- How many latent self-attention blocks follow each cross-attention with the input?
    dropout=0.1, # <- Dropout
    layers=8, # <- This will become two unique layer-blocks: layer 1 and layer 2-8 (using weight sharing).
)

The above model outputs the latents after the final layer. If you want logits instead, use the following model:

from perceiver import PerceiverLogits

model = PerceiverLogits(
    input_channels, # <- How many channels in the input? E.g. 3 for RGB.
    input_shape, # <- How big is the input in the different dimensions? E.g. (28, 28) for MNIST
    output_features, # <- How many different classes? E.g. 10 for MNIST.
    fourier_bands=4, # <- How many bands should the positional encoding have?
    latents=64, # <- How many latent vectors?
    d_model=32, # <- Model dimensionality. Every pixel/token/latent vector will have this size.
    heads=8, # <- How many heads in self-attention? Cross-attention always has 1 head.
    latent_blocks=6, # <- How many latent self-attention blocks follow each cross-attention with the input?
    dropout=0.1, # <- Dropout
    layers=8, # <- This will become two unique layer-blocks: layer 1 and layer 2-8 (using weight sharing).
)
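
The layers comment above says that 8 layers become two unique layer-blocks through weight sharing. Here is a minimal sketch of that idea in plain Python. Block and build_layer_stack are hypothetical names used only for illustration; they are not part of this package, and the real model may organise its layers differently.

```python
class Block:
    """Hypothetical stand-in for one layer-block (cross-attention plus latent self-attention)."""

    def __init__(self, name):
        self.name = name


def build_layer_stack(layers):
    """Return a list of `layers` entries backed by at most two unique Block objects."""
    if layers < 1:
        raise ValueError("layers must be at least 1")
    first = Block("layer 1")
    if layers == 1:
        return [first]
    # Layers 2..N all reference the same object, so they share one set of weights.
    shared = Block("layers 2-%d (shared)" % layers)
    return [first] + [shared] * (layers - 1)


stack = build_layer_stack(8)
unique_blocks = {id(block) for block in stack}
print(len(stack), "layers,", len(unique_blocks), "unique blocks")
```

In this sketch, the depth of the model grows with layers while the number of distinct parameter sets stays at two.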

To use my pre-trained MNIST model (not very good):

from perceiver import load_mnist_model

model = load_mnist_model()

TODO:

  • Positional embedding generalized to n dimensions (with fourier features)
  • Train other models (like CIFAR-100 or something not in the image domain)
  • Type hints
  • Unit tests for components of model
  • Package
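
As a rough illustration of the first TODO item, here is a sketch of Fourier-feature positional encoding for inputs with any number of dimensions. It is not the code in this repository. The frequency spacing (linear from 1 up to half the axis size) and the [-1, 1] coordinate range are assumptions based on the scheme described in the Perceiver paper, and linspace, fourier_encode and positional_encoding are hypothetical helper names.

```python
import math
from itertools import product


def linspace(start, stop, num):
    """Return `num` evenly spaced values from start to stop (inclusive)."""
    if num == 1:
        return [start]
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


def fourier_encode(coord, bands, max_freq):
    """Encode one coordinate in [-1, 1] as [coord, sines..., cosines...]."""
    freqs = linspace(1.0, max_freq / 2.0, bands)  # assumed frequency spacing
    sines = [math.sin(math.pi * f * coord) for f in freqs]
    cosines = [math.cos(math.pi * f * coord) for f in freqs]
    return [coord] + sines + cosines


def positional_encoding(input_shape, bands):
    """Return one feature list per input position, for any number of dimensions."""
    axes = [linspace(-1.0, 1.0, size) for size in input_shape]
    table = []
    for index in product(*(range(size) for size in input_shape)):
        features = []
        for axis, i in enumerate(index):
            features.extend(fourier_encode(axes[axis][i], bands, float(input_shape[axis])))
        table.append(features)
    return table


table = positional_encoding((28, 28), bands=4)
print(len(table), "positions,", len(table[0]), "features each")
```

Each position gets 2 * bands + 1 features per dimension, so the same function handles 1D audio, 2D images, or 3D volumes by changing input_shape only.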
Owner
Louis Arge
Experienced full-stack developer. Self-studying machine learning.