OCRA (Object-Centric Recurrent Attention) source code

Related tags

Deep LearningOCRA
Overview

OCRA (Object-Centric Recurrent Attention) source code

Hossein Adeli and Seoyoung Ahn

Please cite this article if you find this repository useful:


  • For data generation and loading

    1. stimuli_util.ipynb includes all the codes and the instructions for how to generate the datasets for the three tasks; MultiMNIST, MultiMNIST Cluttered and MultiSVHN.
    2. loaddata.py should be updated with the location of the data files for the tasks if not the default used.
  • For training and testing the model:

    1. OCRA_demo.ipynb includes the code for building and training the model. In the first notebook cell, a hyperparameter file should be specified. Parameter files are provided here (different settings are discussed in the supplementary file)

    2. multimnist_params_10glimpse.txt and multimnist_params_3glimpse.txt set all the hyperparameters for MultiMNIST task with 10 and 3 glimpses, respectively.

    OCRA_demo-MultiMNIST_3glimpse_training.ipynb shows how to load a parameter file and train the model.

    1. multimnist_cluttered_params_7glimpse.txt and multimnist_cluttered_params_5glimpse.txt set all the hyperparameters for MultiMNIST Cluttered task with 7 and 5 glimpses, respectively.

    2. multisvhn_params.txt sets all the hyperparameters for the MultiSVHN task with 12 glimpses.

    3. This notebook also includes code for testing a trained model and also for plotting the attention windows for sample images.

    OCRA_demo-cluttered_5steps_loadtrained.ipynb shows how to load a trained model and test it on the test dataset. Example pretrained models are included in the repository under pretrained folder. Download all the pretrained models.

Image-level accuracy averaged from 5 runs

Task (Model name) Error Rate (SD)
MultiMNIST (OCRA-10glimpse) 5.08 (0.17)
Cluttered MultiMNIST (OCRA-7glimpse) 7.12 (1.05)
MultiSVHN (OCRA-12glimpse) 10.07 (0.53)

Validation losses during training

From MultiMNIST OCRA-10glimpse:

From Cluttered MultiMNIST OCRA-7glimpse

Supplementary Results:

Object-centric behavior

The opportunity to observe the object-centric behavior is bigger in the cluttered task. Since the ratio of the glimpse size to the image size is small (covering less than 4 percent of the image), the model needs to optimally move and select the objects to accurately recognize them. Also reducing the number of glimpses has a similar effect, (we experimented with 3 and 5) forcing the model to leverage its object-centric representation to find the objects without being distracted by the noise segments. We include many more examples of the model behavior with both 3 and 5 glimpses to show this behavior.

MultiMNIST Cluttered task with 5 glimpses






MultiMNIST Cluttered task with 3 glimpses





The Street View House Numbers Dataset

We train the model to "read" the digits from left to right by having the order of the predicted sequence match the ground truth from left to right. We allow the model to make 12 glimpses, with the first two not being constrained and the capsule length from every following two glimpses will be read out for the output digit (e.g. the capsule lengths from the 3rd and 4th glimpses are read out to predict digit number 1; the left-most digit and so on). Below are sample behaviors from our model.

The top five rows show the original images, and the bottom five rows show the reconstructions

SVHN_gif

The generation of sample images across 12 glimpses

SVHN_gif

The generatin in a gif fromat

SVHN_gif

The model learns to detect and reconstruct objects. The model achieved ~2.5 percent error rate on recognizing individual digits and ~10 percent error in recognizing whole sequences still lagging SOTA performance on this measure. We believe this to be strongly related to our small two-layer convolutional backbone and we expect to get better results with a deeper one, which we plan to explore next. However, the model shows reasonable attention behavior in performing this task.

Below shows the model's read and write attention behavior as it reads and reconstructs one image.

Herea are a few sample mistakes from our model:

SVHN_error1
ground truth [ 1, 10, 10, 10, 10]
prediction [ 0, 10, 10, 10, 10]

SVHN_error2
ground truth [ 2, 8, 10, 10, 10]
prediction [ 2, 9, 10, 10, 10]

SVHN_error3
ground truth [ 1, 2, 9, 10, 10]
prediction [ 1, 10, 10, 10, 10]

SVHN_error4
ground truth [ 5, 1, 10, 10, 10]
prediction [ 5, 7, 10, 10, 10]


Some MNIST cluttered results

Testing the model on MNIST cluttered dataset with three time steps


Code references:

  1. XifengGuo/CapsNet-Pytorch
  2. kamenbliznashki/generative_models
  3. pitsios-s/SVHN
Owner
Hossein Adeli
Hossein Adeli
An end-to-end library for editing and rendering motion of 3D characters with deep learning [SIGGRAPH 2020]

Deep-motion-editing This library provides fundamental and advanced functions to work with 3D character animation in deep learning with Pytorch. The co

1.2k Dec 29, 2022
Self-Supervised Pre-Training for Transformer-Based Person Re-Identification

Self-Supervised Pre-Training for Transformer-Based Person Re-Identification [pdf] The official repository for Self-Supervised Pre-Training for Transfo

Hao Luo 116 Jan 04, 2023
Small-bets - Ergodic Experiment With Python

Ergodic Experiment Based on this video. Run this experiment with this command: p

Michael Brant 3 Jan 11, 2022
Torchlight2 lan game server tool - A message forwarding tool for Torchlight 2 lan game

Torchlight 2 Lan Game Server Tool A message forwarding tool for Torchlight 2 lan

Huaijun Jiang 3 Nov 01, 2022
Use CLIP to represent video for Retrieval Task

A Straightforward Framework For Video Retrieval Using CLIP This repository contains the basic code for feature extraction and replication of results.

Jesus Andres Portillo Quintero 54 Dec 22, 2022
C3d-pytorch - Pytorch porting of C3D network, with Sports1M weights

C3D for pytorch This is a pytorch porting of the network presented in the paper Learning Spatiotemporal Features with 3D Convolutional Networks How to

Davide Abati 311 Jan 06, 2023
GARCH and Multivariate LSTM forecasting models for Bitcoin realized volatility with potential applications in crypto options trading, hedging, portfolio management, and risk management

Bitcoin Realized Volatility Forecasting with GARCH and Multivariate LSTM Author: Chi Bui This Repository Repository Directory ├── README.md

Chi Bui 113 Dec 29, 2022
On-device speech-to-intent engine powered by deep learning

Rhino Made in Vancouver, Canada by Picovoice Rhino is Picovoice's Speech-to-Intent engine. It directly infers intent from spoken commands within a giv

Picovoice 510 Dec 30, 2022
Learning-based agent for Google Research Football

TiKick 1.Introduction Learning-based agent for Google Research Football Code accompanying the paper "TiKick: Towards Playing Multi-agent Football Full

Tsinghua AI Research Team for Reinforcement Learning 90 Dec 26, 2022
Implementation of C-RNN-GAN.

Implementation of C-RNN-GAN. Publication: Title: C-RNN-GAN: Continuous recurrent neural networks with adversarial training Information: http://mogren.

Olof Mogren 427 Dec 25, 2022
Make differentially private training of transformers easy for everyone

private-transformers This codebase facilitates fast experimentation of differentially private training of Hugging Face transformers. What is this? Why

Xuechen Li 73 Dec 28, 2022
Code for Universal Semi-Supervised Semantic Segmentation models paper accepted in ICCV 2019

USSS_ICCV19 Code for Universal Semi Supervised Semantic Segmentation accepted to ICCV 2019. Full Paper available at https://arxiv.org/abs/1811.10323.

Tarun K 68 Nov 24, 2022
💊 A 3D Generative Model for Structure-Based Drug Design (NeurIPS 2021)

A 3D Generative Model for Structure-Based Drug Design Coming soon... Citation @inproceedings{luo2021sbdd, title={A 3D Generative Model for Structu

Shitong Luo 118 Jan 05, 2023
Face Recognition and Emotion Detector Device

Face Recognition and Emotion Detector Device Orange PI 1 Python 3.10.0 + Django 3.2.9 Project's file explanation Django manage.py Django commands hand

BootyAss 2 Dec 21, 2021
Soomvaar is the repo which 🏩 contains different collection of 👨‍💻🚀code in Python and 💫✨Machine 👬🏼 learning algorithms📗📕 that is made during 📃 my practice and learning of ML and Python✨💥

Soomvaar 📌 Introduction Soomvaar is the collection of various codes implement in machine learning and machine learning algorithms with python on coll

Felix-Ayush 42 Dec 30, 2022
Supplementary materials to "Spin-optomechanical quantum interface enabled by an ultrasmall mechanical and optical mode volume cavity" by H. Raniwala, S. Krastanov, M. Eichenfield, and D. R. Englund, 2022

Supplementary materials to "Spin-optomechanical quantum interface enabled by an ultrasmall mechanical and optical mode volume cavity" by H. Raniwala,

Stefan Krastanov 1 Jan 17, 2022
A Python implementation of the Locality Preserving Matching (LPM) method for pruning outliers in image matching.

LPM_Python A Python implementation of the Locality Preserving Matching (LPM) method for pruning outliers in image matching. The code is established ac

AoxiangFan 11 Nov 07, 2022
Source code for the ACL-IJCNLP 2021 paper entitled "T-DNA: Taming Pre-trained Language Models with N-gram Representations for Low-Resource Domain Adaptation" by Shizhe Diao et al.

T-DNA Source code for the ACL-IJCNLP 2021 paper entitled Taming Pre-trained Language Models with N-gram Representations for Low-Resource Domain Adapta

shizhediao 17 Dec 22, 2022
This is a demo app to be used in the video streaming applications

MoViDNN: A Mobile Platform for Evaluating Video Quality Enhancement with Deep Neural Networks MoViDNN is an Android application that can be used to ev

ATHENA Christian Doppler (CD) Laboratory 7 Jul 21, 2022
Tiny-NewsRec: Efficient and Effective PLM-based News Recommendation

Tiny-NewsRec The source codes for our paper "Tiny-NewsRec: Efficient and Effective PLM-based News Recommendation". Requirements PyTorch == 1.6.0 Tensor

Yang Yu 3 Dec 07, 2022