PyTorch implementation of Value Iteration Networks (VIN): Clean, Simple and Modular. Visualization in Visdom.

Overview

VIN: Value Iteration Networks

This is an implementation of Value Iteration Networks (VIN) in PyTorch, aiming to reproduce the results of the original paper. (A TensorFlow version is also available.)

Architecture of Value Iteration Network

Key idea

  • A fully differentiable neural network with a 'planning' sub-module.
  • Value Iteration = Conv Layer + Channel-wise Max Pooling (see the sketch below).
  • Generalizes better than reactive policies to new, unseen tasks.
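
The core VI step can be written in a few lines of PyTorch. The snippet below is a minimal sketch of the idea on dummy tensors, not the repository's exact module; the tensor and weight names are illustrative.

import torch
import torch.nn.functional as F

# One VI step: convolve the stacked (reward, value) planes into q-channels,
# then take a channel-wise max to produce the next value image.
r = torch.randn(1, 1, 8, 8)                               # reward image
v = torch.zeros(1, 1, 8, 8)                               # current value image
w_q = torch.randn(10, 2, 3, 3)                            # ch_q = 10 "action" filters (illustrative)
q = F.conv2d(torch.cat([r, v], dim=1), w_q, padding=1)    # [1, 10, 8, 8]
v = q.max(dim=1, keepdim=True)[0]                         # channel-wise max -> [1, 1, 8, 8]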

Learned Reward Image and Its Value Images for each VI Iteration

(Figures: the learned reward image and the value images across VI iterations for the 8x8, 16x16, and 28x28 grid worlds.)

Dependencies

This repository requires the following packages:

  • Python >= 3.6
  • Numpy >= 1.12.1
  • PyTorch >= 0.1.10
  • SciPy >= 0.19.0
  • visdom >= 0.1

Datasets

Each data sample consists of the (x, y) coordinates of the current state in the grid world, followed by an obstacle image and a goal image.
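
As a rough illustration of this format (not the repository's loading code; the actual .npz field names may differ), a single 8x8 observation could be assembled like this:

import numpy as np

# Illustrative 8x8 sample: two image channels plus the current (x, y) state.
obstacle = np.zeros((8, 8), dtype=np.float32)   # 0: free, 1: obstacle
goal = np.zeros((8, 8), dtype=np.float32)       # 0: free, 10: goal
obstacle[2, 3] = 1.0                            # place one obstacle (arbitrary choice)
goal[7, 7] = 10.0                               # place the goal (arbitrary choice)
obs = np.stack([obstacle, goal])                # [2, 8, 8] -> the two input channels (ch_i = 2)
state_xy = (0, 0)                               # (x, y) coordinates of the current state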

Dataset size    8x8        16x16      28x28
Train set       77760      776440     4510695
Test set        12960      129440     751905

Running Experiment: Training

Grid world 8x8

python run.py --datafile data/gridworld_8x8.npz --imsize 8 --lr 0.005 --epochs 30 --k 10 --batch_size 128

Grid world 16x16

python run.py --datafile data/gridworld_16x16.npz --imsize 16 --lr 0.008 --epochs 30 --k 20 --batch_size 128

Grid world 28x28

python run.py --datafile data/gridworld_28x28.npz --imsize 28 --lr 0.003 --epochs 30 --k 36 --batch_size 128

Flags:

  • datafile: The path to the data files.
  • imsize: The size of input images. From: [8, 16, 28]
  • lr: Learning rate with RMSProp optimizer. Recommended: [0.01, 0.005, 0.002, 0.001]
  • epochs: Number of epochs to train. Default: 30
  • k: Number of Value Iterations. Recommended: [10 for 8x8, 20 for 16x16, 36 for 28x28]
  • ch_i: Number of channels in the input layer. Default: 2, i.e. the obstacle image and the goal image.
  • ch_h: Number of channels in the first convolutional layer. Default: 150, as described in the paper.
  • ch_q: Number of channels in the q layer (~actions) in the VI module. Default: 10, as described in the paper.
  • batch_size: Batch size. Default: 128
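
For reference, the flags above roughly map onto a training loop like the one below. This is only a hedged sketch: the model and data here are dummies standing in for the repository's VIN module and gridworld loader, and the real training code follows run.py, not this snippet.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Dummy stand-ins so the sketch runs on its own (the real model is the VIN module).
model = nn.Sequential(nn.Conv2d(2, 10, 3, padding=1), nn.AdaptiveAvgPool2d(1), nn.Flatten())
optimizer = torch.optim.RMSprop(model.parameters(), lr=0.005)    # --lr 0.005

dummy_obs = torch.zeros(128, 2, 8, 8)                            # --batch_size 128, --imsize 8
dummy_labels = torch.zeros(128, dtype=torch.long)                # expert actions (placeholder)

for epoch in range(30):                                          # --epochs 30
    logits = model(dummy_obs)                                    # scores over actions
    loss = F.cross_entropy(logits, dummy_labels)                 # imitation loss vs. expert actions
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()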

Visualization with Visdom

We visualize the learned reward image and its corresponding value images for each VI iteration using Visdom.

First, start the server:

python -m visdom.server

Open Visdom in your browser at http://localhost:8097

Then run the following to visualize the learned reward and value images:

python vis.py --datafile learned_rewards_values_28x28.npz
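
If you want to push images to Visdom yourself, a minimal sketch (not the repository's vis.py) could look like the following; the key names inside the .npz file depend on how it was saved.

import numpy as np
import visdom

viz = visdom.Visdom()                              # assumes the server started above is running
data = np.load('learned_rewards_values_28x28.npz')
img = data[data.files[0]]                          # pick one stored array; keys depend on the file
viz.heatmap(X=np.squeeze(img), opts=dict(title='Learned reward image'))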

NOTE: If you would like to produce a GIF animation of the value images on your own, the following ImageMagick command might be useful.

convert -delay 20 -loop 0 *.png value_function.gif

Benchmarks

GPU: TITAN X

Performance: Test Accuracy

NOTE: This is the accuracy on the test set. It differs from the table in the paper, which reports the success rate of rollouts of the learned policy in the environment.

Test Accuracy    8x8       16x16     28x28
PyTorch          99.16%    92.44%    88.20%
TensorFlow       99.03%    90.2%     82%

Speed with GPU

Speed per epoch    8x8    16x16    28x28
PyTorch            3s     15s      100s
TensorFlow         4s     25s      165s

Frequently Asked Questions

  • Q: How do we get the reward image from the observation?

    • A: The observation image has 2 channels. The first channel is the obstacle image (0: free, 1: obstacle). The second channel is the goal image (0: free, 10: goal). For example, in the 8x8 grid world, the shape of an input tensor with batch size 128 is [128, 2, 8, 8]. It is fed into a convolutional layer with a [3, 3] filter and 150 feature maps, followed by another convolutional layer with a [3, 3] filter and 1 feature map. The shape of the output tensor is [128, 1, 8, 8]: this is the reward image (see the sketch after this FAQ).
  • Q: What exactly is the transition model, and how does the VI module obtain the value image from the reward image?

    • A: Let us assume a batch size of 128 in the 8x8 grid world. Once we obtain the reward image with shape [128, 1, 8, 8], we apply a convolutional layer to produce the q layers of the VI module. The [3, 3] filter represents the transition probabilities. There is a set of 10 filters, each generating one feature map in the q layers, and each feature map corresponds to an "action". Note that this is more than the number of actually available actions, which is only 8. We then apply channel-wise max pooling to obtain the value image with shape [128, 1, 8, 8]. Finally, we stack this value image with the reward image for a new VI iteration.
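
Putting the two answers together, a hedged sketch of the whole path from observation to value image looks like this. The layer names and randomly initialized weights are illustrative, not the repository's code.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Observation -> reward image: two conv layers as in the first answer.
obs = torch.zeros(128, 2, 8, 8)                          # [batch, (obstacle, goal), H, W]
h_layer = nn.Conv2d(2, 150, kernel_size=3, padding=1)    # ch_h = 150 feature maps
r_layer = nn.Conv2d(150, 1, kernel_size=3, padding=1)    # 1 feature map -> reward image
r = r_layer(h_layer(obs))                                 # [128, 1, 8, 8]

# Reward image -> value image: k VI iterations as in the second answer.
w_q = torch.randn(10, 2, 3, 3)                            # 10 "action" filters over (reward, value)
v = torch.zeros(128, 1, 8, 8)                             # initial value image
for _ in range(10):                                       # k = 10 for the 8x8 grid world
    q = F.conv2d(torch.cat([r, v], dim=1), w_q, padding=1)  # q layers: [128, 10, 8, 8]
    v, _ = torch.max(q, dim=1, keepdim=True)                 # channel-wise max pooling -> [128, 1, 8, 8]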

References

Further Readings
