A TensorFlow implementation of Attend, Infer, Repeat

Overview

Attend, Infer, Repeat: Fast Scene Understanding with Generative Models

This is an unofficial TensorFlow implementation of Attend, Infer, Repeat (AIR), as presented in the following paper: S. M. Ali Eslami et al., Attend, Infer, Repeat: Fast Scene Understanding with Generative Models.

  • Author (of the implementation): Adam Kosiorek, Oxford Robotics Institute, University of Oxford
  • Email: adamk(at)robots.ox.ac.uk
  • Webpage: http://akosiorek.github.io/

I describe the implementation and the issues I ran into while working on it in this blog post.

Installation

Install TensorFlow v1.1.0rc1, Sonnet v1.1, and the following dependencies, using pip install -r requirements.txt (preferred) or pip install [package]; a quick import check is sketched after the list:

  • matplotlib==1.5.3
  • numpy==1.12.1
  • attrdict==2.0.0
  • scipy==0.18.1
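As a quick sanity check that the pinned versions are in place, the snippet below imports each dependency and prints its version where one is exposed; attrdict and Sonnet are only checked for importability, since I have not verified that they expose a version attribute.

```python
# Sanity check: confirm the pinned dependencies import at the expected versions.
import matplotlib
import numpy
import scipy
import attrdict            # importability only
import sonnet              # Sonnet v1.1; importability only
import tensorflow as tf

print(matplotlib.__version__)  # expected: 1.5.3
print(numpy.__version__)       # expected: 1.12.1
print(scipy.__version__)       # expected: 0.18.1
print(tf.__version__)          # expected: 1.1.0-rc1
```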

Sample Results

AIR learns to reconstruct objects by painting them one by one onto a blank canvas. The figure below comes from a model trained for 175k iterations; the maximum number of steps is set to 3, but there are never more than 2 objects. The first row shows the input images; rows 2-4 are the reconstructions after steps 1, 2 and 3 (with the location of the attention glimpse, if any, marked in red). Rows 5-7 are the reconstructed image crops, and above each crop is the probability of executing 1, 2 or 3 steps. If a reconstructed crop is black and "0 with ..." is written above it, that step was not used.

AIR results
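For intuition, here is a minimal numpy-only sketch of the composition scheme described above: each step decodes one glimpse and paints it onto a shared canvas. The plain array paste used here is an assumption standing in for the differentiable spatial transformer the actual model uses to place glimpses.

```python
import numpy as np

def paint_steps(glimpses, corners, canvas_shape=(50, 50)):
    """Compose a reconstruction by painting one glimpse per step.

    glimpses: list of 2-D arrays in [0, 1]; corners: matching (row, col)
    top-left positions. Contributions are summed and clipped, mimicking
    the step-by-step additive reconstruction.
    """
    canvas = np.zeros(canvas_shape, dtype=np.float32)
    for glimpse, (r, c) in zip(glimpses, corners):
        h, w = glimpse.shape
        canvas[r:r + h, c:c + w] += glimpse
    return np.clip(canvas, 0.0, 1.0)

# Two 20x20 stand-in "digits" painted in two steps:
digits = [np.random.rand(20, 20), np.random.rand(20, 20)]
reconstruction = paint_steps(digits, [(5, 5), (25, 25)])
```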

Data

Run ./scripts/create_dataset.sh. The script creates train and validation datasets of multi-digit MNIST.
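The exact parameters live in the script itself; the sketch below only illustrates the general recipe, with canvas size, digit count, and placement policy all being illustrative assumptions: scatter a random number of MNIST digits onto a larger blank canvas and record the digit count.

```python
import numpy as np

def make_multi_mnist(digits, canvas_size=50, max_digits=2, rng=None):
    """digits: (N, 28, 28) array of MNIST images in [0, 1]."""
    if rng is None:
        rng = np.random.RandomState(0)
    canvas = np.zeros((canvas_size, canvas_size), dtype=np.float32)
    n = rng.randint(0, max_digits + 1)  # 0, 1 or 2 digits per image
    for _ in range(n):
        digit = digits[rng.randint(len(digits))]
        r, c = rng.randint(0, canvas_size - 28 + 1, size=2)
        canvas[r:r + 28, c:c + 28] = np.maximum(canvas[r:r + 28, c:c + 28], digit)
    return canvas, n

# Random arrays standing in for real MNIST digits:
fake_digits = np.random.rand(10, 28, 28).astype(np.float32)
image, count = make_multi_mnist(fake_digits)
```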

Training

Run ./scripts/train_multi_mnist.sh. The training script runs for 300k iterations and saves model checkpoints and training-progress figures every 10k iterations in results/multi_mnist. TensorFlow summaries are stored in the same folder, so TensorBoard can be used for monitoring.
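For monitoring, point TensorBoard at the results folder, e.g. tensorboard --logdir=results/multi_mnist.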

The model seems to be very sensitive to initialisation. It might be necessary to run training multiple times before achieving a step-count accuracy close to the one reported in the paper.
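If runs diverge between restarts, fixing the seeds explicitly makes them comparable. The snippet below uses the standard TF 1.x seeding calls; whether the training script exposes a seed option is an assumption I have not verified, so this may need to be wired into the training code directly.

```python
import numpy as np
import tensorflow as tf

def set_seeds(seed):
    """Fix the numpy and graph-level TensorFlow seeds (TF 1.x)."""
    np.random.seed(seed)
    tf.set_random_seed(seed)

for seed in (0, 1, 2):
    tf.reset_default_graph()
    set_seeds(seed)
    # ... rebuild the model and run training here, then compare the
    # resulting step-count accuracies across seeds ...
```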

Experimentation

The Jupyter notebook at attend_infer_repeat/experiment.ipynb can be used for experimentation.
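A common first step in the notebook is loading a trained model back into a session. Below is a minimal TF 1.x sketch of the checkpoint round trip; weights is a toy stand-in for the model's variables, and in practice you would rebuild the AIR graph and point the saver at a checkpoint in results/multi_mnist instead.

```python
import os
import tensorflow as tf

# Toy variable standing in for the model's parameters.
weights = tf.get_variable('weights', shape=[3], initializer=tf.zeros_initializer())
saver = tf.train.Saver()

if not os.path.isdir('/tmp/air_demo'):
    os.makedirs('/tmp/air_demo')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    path = saver.save(sess, '/tmp/air_demo/model.ckpt')

with tf.Session() as sess:
    saver.restore(sess, path)  # restores 'weights' from the checkpoint
    print(sess.run(weights))
```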

Citation

If you find this repo useful in your research, please consider citing the original paper:

@incollection{Eslami2016,
    title = {Attend, Infer, Repeat: Fast Scene Understanding with Generative Models},
    author = {Eslami, S. M. Ali and Heess, Nicolas and Weber, Theophane and Tassa, Yuval and Szepesvari, David and kavukcuoglu, koray and Hinton, Geoffrey E},
    booktitle = {Advances in Neural Information Processing Systems 29},
    editor = {D. D. Lee and M. Sugiyama and U. V. Luxburg and I. Guyon and R. Garnett},
    pages = {3225--3233},
    year = {2016},
    publisher = {Curran Associates, Inc.},
    url = {http://papers.nips.cc/paper/6230-attend-infer-repeat-fast-scene-understanding-with-generative-models.pdf}
}

License

This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; either version 3 of the License, or (at your option) any later version.

This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.

You should have received a copy of the GNU General Public License along with this program. If not, see http://www.gnu.org/licenses/.

Release Notes

Version 1.0

  • Original unofficial implementation; contains the multi-digit MNIST experiment.