A Tensorflow implementation of CapsNet based on Geoffrey Hinton's paper Dynamic Routing Between Capsules

Overview

CapsNet-Tensorflow

Contributions welcome License Gitter

A Tensorflow implementation of CapsNet based on Geoffrey Hinton's paper Dynamic Routing Between Capsules

capsVSneuron

Notes:

  1. The current version supports MNIST and Fashion-MNIST datasets. The current test accuracy for MNIST is 99.64%, and Fashion-MNIST 90.60%, see details in the Results section
  2. See dist_version for multi-GPU support
  3. Here(ηŸ₯乎) is an article explaining my understanding of the paper. It may be helpful in understanding the code.

Important:

If you need to apply CapsNet model to your own datasets or build up a new model with the basic block of CapsNet, please follow my new project CapsLayer, which is an advanced library for capsule theory, aiming to integrate capsule-relevant technologies, provide relevant analysis tools, develop related application examples, and promote the development of capsule theory. For example, you can use capsule layer block in your code easily with the API capsLayer.layers.fully_connected and capsLayer.layers.conv2d

Requirements

  • Python
  • NumPy
  • Tensorflow>=1.3
  • tqdm (for displaying training progress info)
  • scipy (for saving images)

Usage

Step 1. Download this repository with git or click the download ZIP button.

$ git clone https://github.com/naturomics/CapsNet-Tensorflow.git
$ cd CapsNet-Tensorflow

Step 2. Download MNIST or Fashion-MNIST dataset. In this step, you have two choices:

  • a) Automatic downloading with download_data.py script
$ python download_data.py   (for mnist dataset)
$ python download_data.py --dataset fashion-mnist --save_to data/fashion-mnist (for fashion-mnist dataset)
  • b) Manual downloading with wget or other tools, move and extract dataset into data/mnist or data/fashion-mnist directory, for example:
$ mkdir -p data/mnist
$ wget -c -P data/mnist http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
$ wget -c -P data/mnist http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
$ wget -c -P data/mnist http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
$ wget -c -P data/mnist http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
$ gunzip data/mnist/*.gz

Step 3. Start the training(Using the MNIST dataset by default):

$ python main.py
$ # or training for fashion-mnist dataset
$ python main.py --dataset fashion-mnist
$ # If you need to monitor the training process, open tensorboard with this command
$ tensorboard --logdir=logdir
$ # or use `tail` command on linux system
$ tail -f results/val_acc.csv

Step 4. Calculate test accuracy

$ python main.py --is_training=False
$ # for fashion-mnist dataset
$ python main.py --dataset fashion-mnist --is_training=False

Note: The default parameters of batch size is 128, and epoch 50. You may need to modify the config.py file or use command line parameters to suit your case, e.g. set batch size to 64 and do once test summary every 200 steps: python main.py --test_sum_freq=200 --batch_size=48

Results

The pictures here are plotted by tensorboard and my tool plot_acc.R

  • training loss

total_loss margin_loss reconstruction_loss

Here are the models I trained and my talk and something else:

Baidu Netdisk(password:ahjs)

  • The best val error(using reconstruction)
Routing iteration 1 3 4
val error 0.36 0.36 0.41
Paper 0.29 0.25 -

test_acc

My simple comments for capsule

  1. A new version neural unit(vector in vector out, not scalar in scalar out)
  2. The routing algorithm is similar to attention mechanism
  3. Anyway, a great potential work, a lot to be built upon

My weChat:

my_wechat

Reference

Owner
Huadong Liao
Explore Nature from an Omics Perspective
Huadong Liao
Analysing poker data from home games with friends

Poker Game Analysis Analysing poker data from home games with friends. Not a lot of data is collected, so this project is primarily focussed on descri

Stavros Karmaniolos 1 Oct 15, 2022
Continuum Learning with GEM: Gradient Episodic Memory

Gradient Episodic Memory for Continual Learning Source code for the paper: @inproceedings{GradientEpisodicMemory, title={Gradient Episodic Memory

Facebook Research 360 Dec 27, 2022
Post-Training Quantization for Vision transformers.

PTQ4ViT Post-Training Quantization Framework for Vision Transformers. We use the twin uniform quantization method to reduce the quantization error on

Zhihang Yuan 61 Dec 28, 2022
Code for LIGA-Stereo Detector, ICCV'21

LIGA-Stereo Introduction This is the official implementation of the paper LIGA-Stereo: Learning LiDAR Geometry Aware Representations for Stereo-based

Xiaoyang Guo 75 Dec 09, 2022
Pretrained Cost Model for Distributed Constraint Optimization Problems

Pretrained Cost Model for Distributed Constraint Optimization Problems Requirements PyTorch 1.9.0 PyTorch Geometric 1.7.1 Directory structure baseline

2 Aug 28, 2022
StarGAN v2 - Official PyTorch Implementation (CVPR 2020)

StarGAN v2 - Official PyTorch Implementation StarGAN v2: Diverse Image Synthesis for Multiple Domains Yunjey Choi*, Youngjung Uh*, Jaejun Yoo*, Jung-W

Clova AI Research 3.1k Jan 09, 2023
DiAne is a smart fuzzer for IoT devices

Diane Diane is a fuzzer for IoT devices. Diane works by identifying fuzzing triggers in the IoT companion apps to produce valid yet under-constrained

seclab 28 Jan 04, 2023
This is an example implementation of the paper "Cross Domain Robot Imitation with Invariant Representation".

IR-GAIL This is an example implementation of the paper "Cross Domain Robot Imitation with Invariant Representation". Dependency The experiments are de

Zhao-Heng Yin 1 Jul 14, 2022
ShapeGlot: Learning Language for Shape Differentiation

ShapeGlot: Learning Language for Shape Differentiation Created by Panos Achlioptas, Judy Fan, Robert X.D. Hawkins, Noah D. Goodman, Leonidas J. Guibas

Panos 32 Dec 23, 2022
PyTorch implementation of the paper: Label Noise Transition Matrix Estimation for Tasks with Lower-Quality Features

Label Noise Transition Matrix Estimation for Tasks with Lower-Quality Features Estimate the noise transition matrix with f-mutual information. This co

<a href=[email protected]"> 1 Jun 05, 2022
links and status of cool gradio demos

awesome-demos This is a list of some wonderful demos & applications built with Gradio. Here's how to contribute yours! πŸ–ŠοΈ Natural language processing

Gradio 96 Dec 30, 2022
Pytorch library for seismic data augmentation

Pytorch library for seismic data augmentation

Artemii Novoselov 27 Nov 22, 2022
A new version of the CIDACS-RL linkage tool suitable to a cluster computing environment.

Fully Distributed CIDACS-RL The CIDACS-RL is a brazillian record linkage tool suitable to integrate large amount of data with high accuracy. However,

Robespierre Pita 5 Nov 04, 2022
Code for "Learning to Regrasp by Learning to Place"

Learning2Regrasp Learning to Regrasp by Learning to Place, CoRL 2021. Introduction We propose a point-cloud-based system for robots to predict a seque

Shuo Cheng (ζˆη‘•) 18 Aug 27, 2022
Geometric Vector Perceptron --- a rotation-equivariant GNN for learning from biomolecular structure

Geometric Vector Perceptron Code to accompany Learning from Protein Structure with Geometric Vector Perceptrons by B Jing, S Eismann, P Suriana, RJL T

Dror Lab 85 Dec 29, 2022
Based on the given clinical dataset, Predict whether the patient having Heart Disease or Not having Heart Disease

Heart_Disease_Classification Based on the given clinical dataset, Predict whether the patient having Heart Disease or Not having Heart Disease Dataset

Ashish 1 Jan 30, 2022
Out-of-distribution detection using the pNML regret. NeurIPS2021

OOD Detection Load conda environment conda env create -f environment.yml or install requirements: while read requirement; do conda install --yes $requ

Koby Bibas 23 Dec 02, 2022
Neural network for stock price prediction

neural_network_for_stock_price_prediction Neural networks for stock price predic

2 Feb 04, 2022
Fader Networks: Manipulating Images by Sliding Attributes - NIPS 2017

FaderNetworks PyTorch implementation of Fader Networks (NIPS 2017). Fader Networks can generate different realistic versions of images by modifying at

Facebook Research 753 Dec 23, 2022
Born-Infeld (BI) for AI: Energy-Conserving Descent (ECD) for Optimization

Born-Infeld (BI) for AI: Energy-Conserving Descent (ECD) for Optimization This repository contains the code for the BBI optimizer, introduced in the p

G. Bruno De Luca 5 Sep 06, 2022