PyTorch implementation of Value Iteration Networks (NIPS 2016 best paper)

Overview

VIN: Value Iteration Networks

[Figure: Architecture of Value Iteration Network]

A quick thank you

A few others have released amazing related work which helped inspire and improve my own implementation. It goes without saying that this release would not be nearly as good if it were not for all of the following:

Why another VIN implementation?

  1. The PyTorch VIN model in this repository is, in my opinion, more readable and closer to the original Theano implementation than others I have found (both TensorFlow and PyTorch).
  2. This is not simply an implementation of the VIN model in PyTorch; it is also a full Python implementation of the gridworld environments used in the original MATLAB implementation.
  3. It provides a more extensible research base for others to build on, without needing to get past a possible MATLAB paywall.

Installation

This repository requires the following packages:

Use pip to install the necessary dependencies:

pip install -U -r requirements.txt 

Note that PyTorch cannot be installed directly from PyPI; refer to http://pytorch.org/ for custom installation instructions specific to your needs.

How to train

8x8 gridworld

python train.py --datafile dataset/gridworld_8x8.npz --imsize 8 --lr 0.005 --epochs 30 --k 10 --batch_size 128

16x16 gridworld

python train.py --datafile dataset/gridworld_16x16.npz --imsize 16 --lr 0.002 --epochs 30 --k 20 --batch_size 128

28x28 gridworld

python train.py --datafile dataset/gridworld_28x28.npz --imsize 28 --lr 0.002 --epochs 30 --k 36 --batch_size 128

Flags:

  • datafile: Path to the dataset file (e.g. dataset/gridworld_8x8.npz).
  • imsize: The size of input images. One of: [8, 16, 28]
  • lr: Learning rate for the RMSProp optimizer. Recommended: [0.01, 0.005, 0.002, 0.001]
  • epochs: Number of epochs to train. Default: 30
  • k: Number of Value Iterations. Recommended: [10 for 8x8, 20 for 16x16, 36 for 28x28]
  • l_i: Number of channels in input layer. Default: 2, i.e. obstacles image and goal image.
  • l_h: Number of channels in first convolutional layer. Default: 150, described in paper.
  • l_q: Number of channels in q layer (~actions) in VI-module. Default: 10, described in paper.
  • batch_size: Batch size. Default: 128
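The k and l_q flags control the value-iteration (VI) module. As a rough, framework-free illustration, the NumPy sketch below alternates a 3x3 convolution over the stacked reward and value maps with a max over the l_q channels, then reads out the q-values at the agent's state (s1, s2) using the same two-step gather pattern that the issue "Don't understand VIN last step" asks about. It is a simplified sketch of the recurrence described in the paper, assuming a 1-channel reward map and 3x3 kernels; it is not the repository's model.py, which may differ in details such as the exact number of updates. All function names are hypothetical, and the second gather (on s2) is an assumption, since the issue only quotes the first.

```python
import numpy as np


def conv3x3(x, w):
    """Zero-padded 3x3 cross-correlation, as computed by a conv2d layer with padding=1.
    x: (B, C_in, H, W), w: (C_out, C_in, 3, 3) -> (B, C_out, H, W)."""
    b, _, h, wd = x.shape
    xp = np.pad(x, ((0, 0), (0, 0), (1, 1), (1, 1)))
    out = np.zeros((b, w.shape[0], h, wd))
    for di in range(3):
        for dj in range(3):
            # Contract the input channels for this kernel tap at every position.
            out += np.einsum("oc,bchw->bohw", w[:, :, di, dj],
                             xp[:, :, di:di + h, dj:dj + wd])
    return out


def vi_module(r, w_q, w_v, k):
    """k rounds of: q = conv([r, v]); v = max over the l_q channels.
    r: (B, 1, H, W) reward map; w_q, w_v: (l_q, 1, 3, 3). Returns q: (B, l_q, H, W)."""
    q = conv3x3(r, w_q)                        # first round: no value map yet
    v = q.max(axis=1, keepdims=True)           # (B, 1, H, W)
    w = np.concatenate([w_q, w_v], axis=1)     # (l_q, 2, 3, 3), applied to [r, v]
    for _ in range(k - 1):
        q = conv3x3(np.concatenate([r, v], axis=1), w)
        v = q.max(axis=1, keepdims=True)
    return q


def select_q_at_state(q, s1, s2):
    """Return out[b, a] = q[b, a, s1[b], s2[b]] via two gathers along axis 2.
    q: (B, l_q, H, W); s1, s2: (B,) integer coordinates."""
    b, l_q, _, wd = q.shape
    # Gather 1: pick row s1[b] for every channel and column (mirrors expand/permute/gather on S1).
    idx1 = np.broadcast_to(s1.reshape(b, 1, 1, 1), (b, l_q, 1, wd))
    q_row = np.take_along_axis(q, idx1, axis=2).squeeze(2)     # (B, l_q, W)
    # Gather 2 (assumed analogue for S2): pick column s2[b].
    idx2 = np.broadcast_to(s2.reshape(b, 1, 1), (b, l_q, 1))
    return np.take_along_axis(q_row, idx2, axis=2).squeeze(2)  # (B, l_q)


rng = np.random.default_rng(0)
r = rng.standard_normal((2, 1, 8, 8))              # hypothetical reward maps
w_q, w_v = rng.standard_normal((2, 10, 1, 3, 3))   # l_q = 10
q = vi_module(r, w_q, w_v, k=10)
q_out = select_q_at_state(q, np.array([1, 6]), np.array([3, 0]))
print(q_out.shape)  # (2, 10)
```

The keepdims=True in the max keeps the channel axis so v can be concatenated with r. If the max dropped that axis, the concatenation would fail with a shape mismatch; a mismatch of this kind is one plausible source of the "inconsistent tensor sizes" error at torch.cat([r, v], 1) reported in the issues.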

How to test / visualize paths (requires training first)

8x8 gridworld

python test.py --weights trained/vin_8x8.pth --imsize 8 --k 10

16x16 gridworld

python test.py --weights trained/vin_16x16.pth --imsize 16 --k 20

28x28 gridworld

python test.py --weights trained/vin_28x28.pth --imsize 28 --k 36

To visualize the optimal and predicted paths, simply pass:

--plot

Flags:

  • weights: Path to trained weights.
  • imsize: The size of input images. One of: [8, 16, 28]
  • plot: If supplied, the optimal and predicted paths will be plotted.
  • k: Number of Value Iterations. Recommended: [10 for 8x8, 20 for 16x16, 36 for 28x28]
  • l_i: Number of channels in input layer. Default: 2, i.e. obstacles image and goal image.
  • l_h: Number of channels in first convolutional layer. Default: 150, described in paper.
  • l_q: Number of channels in q layer (~actions) in VI-module. Default: 10, described in paper.
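For reference, the test flags above map naturally onto an argparse parser. The sketch below is hypothetical and is not the repository's test.py: it uses the documented defaults for l_i, l_h and l_q, restricts imsize to the documented sizes, treats --plot as an on/off switch, and makes --weights and --k required because their defaults are not stated here.

```python
import argparse


def build_test_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the documented test flags."""
    p = argparse.ArgumentParser(description="Evaluate a trained VIN (sketch)")
    p.add_argument("--weights", type=str, required=True, help="Path to trained weights")
    p.add_argument("--imsize", type=int, choices=[8, 16, 28], required=True,
                   help="Size of input images")
    p.add_argument("--plot", action="store_true",
                   help="Plot the optimal and predicted paths")
    p.add_argument("--k", type=int, required=True, help="Number of value iterations")
    p.add_argument("--l_i", type=int, default=2, help="Input channels (obstacles, goal)")
    p.add_argument("--l_h", type=int, default=150, help="Channels in first conv layer")
    p.add_argument("--l_q", type=int, default=10, help="Channels in q layer (~actions)")
    return p


# Parse the 8x8 test command from above (explicit list, so sys.argv is not used).
args = build_test_parser().parse_args(
    ["--weights", "trained/vin_8x8.pth", "--imsize", "8", "--k", "10"])
print(args.plot, args.l_h)  # False 150
```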

Results

[Figure: Sample One and Sample Two paths for the 8x8, 16x16, and 28x28 gridworlds]

Datasets

Each data sample consists of an obstacle image and a goal image followed by the (x, y) coordinates of the current state in the gridworld.

Dataset size | 8x8   | 16x16  | 28x28
Train set    | 81337 | 456309 | 1529584
Test set     | 13846 | 77203  | 251755
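A back-of-the-envelope estimate helps explain the memory pressure, including the MemoryError at images.astype(np.float32) reported in the issues for the 16x16 set. The sketch assumes, as an illustration rather than a description of the actual file layout, that each training sample stores one 2 x imsize x imsize image (obstacles and goal, matching l_i = 2) held as 4-byte float32 values; the helper name is hypothetical.

```python
def float32_image_bytes(n_samples: int, imsize: int, channels: int = 2) -> int:
    """Bytes needed for n_samples images of shape (channels, imsize, imsize) in float32."""
    return n_samples * channels * imsize * imsize * 4  # 4 bytes per float32 value


TRAIN_SIZES = {8: 81337, 16: 456309, 28: 1529584}  # train-set sizes from the table

for imsize, n in TRAIN_SIZES.items():
    print(f"{imsize}x{imsize}: {float32_image_bytes(n, imsize) / 1e9:.2f} GB")
# 8x8: 0.04 GB
# 16x16: 0.93 GB
# 28x28: 9.59 GB
```

Under these assumptions the 28x28 training images alone come to roughly 9.6 GB. Because astype returns a new array by default, the original array and its float32 copy briefly coexist while loading, so peak usage is higher still.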

The datasets (8x8, 16x16, and 28x28) included in this repository can be reproduced using the dataset/make_training_data.py script. Note that this script is not optimized and runs rather slowly (it also uses a lot of memory :D).
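The traceback in the issues shows dataset.py reading S1 = f['arr_1'] from the .npz file. That name comes from NumPy's convention for positional arguments to np.savez, which stores them as arr_0, arr_1, ... in call order. The sketch below demonstrates this on a tiny in-memory archive; the toy arrays, their order and the action labels are hypothetical and are not taken from make_training_data.py.

```python
import io

import numpy as np

# Hypothetical toy data: 3 samples on an 8x8 grid (not the real dataset).
n, imsize = 3, 8
images = np.zeros((n, 2, imsize, imsize), dtype=np.uint8)  # channel 0: obstacles, 1: goal
S1 = np.array([1, 2, 3])          # x coordinate of the current state
S2 = np.array([4, 5, 6])          # y coordinate of the current state
labels = np.array([0, 7, 3])      # hypothetical action labels

buf = io.BytesIO()
# Positional arrays are stored under the names arr_0, arr_1, ... in call order.
np.savez(buf, images, S1, S2, labels)
buf.seek(0)

with np.load(buf) as f:
    print(sorted(f.files))        # ['arr_0', 'arr_1', 'arr_2', 'arr_3']
    s1_loaded = f["arr_1"]        # the key dataset.py reads for S1, per the traceback
```

An archive saved with keyword arguments (for example np.savez(path, images=images, ...)) stores those names instead, so a lookup of f['arr_1'] would raise the "arr_1 is not a file in the archive" KeyError reported in the issues. Printing f.files is a quick way to see which keys a given file actually contains.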

Performance: Success Rate

This is the success rate from rollouts of the learned policy in the environment (taken over 5000 randomly generated domains).

Success Rate | 8x8    | 16x16  | 28x28
PyTorch      | 99.69% | 96.99% | 91.07%

Performance: Test Accuracy

NOTE: This is the accuracy on the test set. It differs from the table in the paper, which reports the success rate from rollouts of the learned policy in the environment.

Test Accuracy | 8x8    | 16x16  | 28x28
PyTorch       | 99.83% | 94.84% | 88.54%
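The two tables measure different things: test accuracy scores each state independently against a labelled action, while the success rate requires a whole rollout to reach the goal. The pure-Python sketch below makes the distinction concrete. The move set, grid encoding (1 = obstacle), step budget and function names are hypothetical choices, not the repository's test.py.

```python
from typing import Callable, List, Sequence, Tuple

State = Tuple[int, int]
# Hypothetical 8-connected move set as (dx, dy); the repository's action order may differ.
MOVES = [(-1, 0), (1, 0), (0, 1), (0, -1), (-1, 1), (-1, -1), (1, 1), (1, -1)]


def action_accuracy(predicted: Sequence[int], labels: Sequence[int]) -> float:
    """Test accuracy: fraction of states whose predicted action equals the label."""
    return sum(p == t for p, t in zip(predicted, labels)) / len(labels)


def rollout_succeeds(policy: Callable[[State], int], grid: List[List[int]],
                     start: State, goal: State, max_steps: int) -> bool:
    """Follow `policy` from `start`; succeed only if `goal` is reached within
    `max_steps` moves without leaving the grid or entering an obstacle cell (1)."""
    x, y = start
    for _ in range(max_steps):
        if (x, y) == goal:
            return True
        dx, dy = MOVES[policy((x, y))]
        x, y = x + dx, y + dy
        if not (0 <= x < len(grid) and 0 <= y < len(grid[0])) or grid[x][y] == 1:
            return False
    return (x, y) == goal


def success_rate(outcomes: Sequence[bool]) -> float:
    """Success rate: fraction of rollouts (one per domain) that reached the goal."""
    return sum(outcomes) / len(outcomes)
```

Because a single bad move can end a rollout, the success rate can fall below per-state accuracy (as for 8x8 above). Conversely, an action that disagrees with the label may still lie on another shortest path, so the success rate can also exceed test accuracy (as for 16x16 and 28x28).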
Comments
  • testing accuracy fairly low

    I just tried to follow the instructions in the repo and tested the trained models, but got fairly low accuracy. I'm using PyTorch 0.1.12_1. Is there anything I should pay attention to?

    opened by xinleipan 10
  • Prebuilt Dataset Generation

    Hello,

    I was wondering how you generated the prebuilt datasets that are downloaded when running download_weights_and_datasets.sh, i.e. what were the max_obs and max_obs_size parameters?

    Did you follow this file in the original repo? https://github.com/avivt/VIN/blob/master/scripts/make_data_gridworld_nips.m

    Thanks, Emilio

    opened by eparisotto 5
  • the rollout accuracy in test script is lower than the test accuracy in train script.

    Hello!

    I have a small question. Does the rollout accuracy indicate the success rate? If so, why is it lower than the prediction accuracy? In Aviv's implementation, the success rate on the 8x8 gridworld was as high as 99.6%. Why is the success rate in your experiment relatively low?

    Thanks!

    opened by albzni 4
  • RUN ERROR

    When I run 'python train.py --datafile dataset/gridworld_8x8.npz --imsize 8 --lr 0.005 --epochs 30 --k 10 --batch_size 128', it works, but when I then run 'python train.py --datafile dataset/gridworld_16x16.npz --imsize 16 --lr 0.002 --epochs 30 --k 20 --batch_size 128', the following error occurs:

    ~/pytorch-value-iteration-networks$ python train.py --datafile dataset/gridworld_16x16.npz --imsize 16 --lr 0.002 --epochs 10 --k 20 --batch_size 128
    Traceback (most recent call last):
      File "train.py", line 135, in <module>
        config.datafile, imsize=config.imsize, train=True, transform=transform)
      File "/home/ni/pytorch-value-iteration-networks/dataset/dataset.py", line 22, in __init__
        self._process(file, self.train)
      File "/home/ni/pytorch-value-iteration-networks/dataset/dataset.py", line 58, in _process
        images = images.astype(np.float32)
    MemoryError

    opened by N-Kingsley 3
  • Problem of running the test script

    Hello,

    I downloaded the data with the .sh download script you provided, and I also got an nps weights file after training. When I ran the testing command, I got the following error:

    Traceback (most recent call last):
      File "/home/research/DL/VIN/pytorch-value-iteration-networks/test.py", line 158, in <module>
        main(config)
      File "/home/research/DL/VIN/pytorch-value-iteration-networks/test.py", line 85, in main
        _, predictions = vin(X_in, S1_in, S2_in, config)
      File "/usr/local/lib/python2.7/dist-packages/torch/nn/modules/module.py", line 357, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/research/DL/VIN/pytorch-value-iteration-networks/model.py", line 64, in forward
        return logits, self.sm(logits)
      File "/usr/local/lib/python2.7/dist-packages/torch/nn/modules/module.py", line 352, in __call__
        for hook in self._forward_pre_hooks.values():
      File "/usr/local/lib/python2.7/dist-packages/torch/nn/modules/module.py", line 398, in __getattr__
        type(self).__name__, name))
    AttributeError: 'Softmax' object has no attribute '_forward_pre_hooks'

    Thanks for helping!

    opened by YantianZha 3
  • Improved readability of the VIN model, in addition to minor changes

    My main modification is in the forward method of the model where you extract the q_out from the q values, and not repeating q = F.conv2d(...) in two places. I also made minor improvements, such as adding argparse in the dataset creation script and changing .cuda() into .to(device) in test.py.

    opened by shuishida 2
  • Inconsistent tensor sizes when starting training

    Hey there. I'm trying to run

    python train.py --datafile dataset/gridworld_8x8.npz --imsize 8 --lr 0.005 --epochs 30 --k 10 --batch_size 128
    

    But I get the following error

    Number of Train Samples: 103926
    Number of Test Samples: 17434
         Epoch | Train Loss | Train Error | Epoch Time
    Traceback (most recent call last):
      File "train.py", line 147, in <module>
        train(net, trainloader, config, criterion, optimizer, use_GPU)
      File "train.py", line 40, in train
        outputs, predictions = net(X, S1, S2, config)
      File "/home/j1k1000o/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 224, in __call__
        result = self.forward(*input, **kwargs)
      File "/media/user_home2/j1k1000o/j1k/VINs/pytorch-value-iteration-networks/model.py", line 44, in forward
        q = F.conv2d(torch.cat([r, v], 1), 
      File "/home/j1k1000o/anaconda3/lib/python3.6/site-packages/torch/autograd/variable.py", line 897, in cat
        return Concat.apply(dim, *iterable)
      File "/home/j1k1000o/anaconda3/lib/python3.6/site-packages/torch/autograd/_functions/tensor.py", line 317, in forward
        return torch.cat(inputs, dim)
    RuntimeError: inconsistent tensor sizes at /opt/conda/conda-bld/pytorch_1502009910772/work/torch/lib/THC/generic/THCTensorMath.cu:141
    

    I've executed

    ./download_weights_and_datasets.sh
    

    as well as

    python ./dataset/make_training_data.py
    

    And I'm running it on Ubuntu 16.04 with Python 3.6 and all the requirements installed.

    Can you help me out?

    opened by juancprzs 2
  • Don't understand VIN last step

        slice_s1 = S1.long().expand(config.imsize, 1, config.l_q, q.size(0))
        slice_s1 = slice_s1.permute(3, 2, 1, 0)
        q_out = q.gather(2, slice_s1).squeeze(2)
    

    What do these three lines do?

    opened by QiXuanWang 1
  • KeyError: 'arr_1 is not a file in the archive'

    python3 train.py --datafile dataset/gridworld_8x8.npz --imsize 8 --lr 0.005 --epochs 30 --k 10 --batch_size 128
    Traceback (most recent call last):
      File "train.py", line 135, in <module>
        config.datafile, imsize=config.imsize, train=True, transform=transform)
      File "/home/user/pytorch/tutorials/valueiterationnetworks/pytorch-value-iteration-networks/dataset/dataset.py", line 22, in __init__
        self._process(file, self.train)
      File "/home/user/pytorch/tutorials/valueiterationnetworks/pytorch-value-iteration-networks/dataset/dataset.py", line 49, in _process
        S1 = f['arr_1']
      File "/home/user/miniconda3/lib/python3.6/site-packages/numpy/lib/npyio.py", line 255, in __getitem__
        raise KeyError("%s is not a file in the archive" % key)
    KeyError: 'arr_1 is not a file in the archive'

    I got this error, could you please

    opened by derelearnro 1
  • Problem of running dataset/make_training_data.py script

    Hi

    When I tried to run the make_training_data.py script to generate the gridworld.npz file, I got the following error:

    FileNotFoundError: [Errno 2] No such file or directory: 'dataset/gridworld_28x28.npz'
    

    And I found that line 101 should be modified as follows:

    save_path = "gridworld_{0}x{1}".format(dom_size[0], dom_size[1])
    
    opened by ruqing00 0
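Regarding the FileNotFoundError for 'dataset/gridworld_28x28.npz' in the last issue: a relative path like that is resolved against the current working directory, so it presumably fails when the script is run from a directory without a dataset/ subfolder (for example, from inside dataset/ itself). The suggested fix drops the prefix, which works from inside dataset/ but not from the repository root. A hypothetical alternative, sketched below, anchors the path to the script's own location so it works from either place.

```python
from pathlib import Path


def dataset_save_path(dom_size, base_dir=None):
    """Hypothetical helper: build gridworld_<H>x<W>.npz next to the script
    (or under base_dir), independent of the current working directory."""
    # __file__ is only defined when this code runs from a file (script or module).
    base = Path(base_dir) if base_dir is not None else Path(__file__).resolve().parent
    return base / "gridworld_{0}x{1}.npz".format(dom_size[0], dom_size[1])
```

Note that np.savez appends the .npz extension when given a filename that lacks it, which is why the suggested save_path without an extension still produces a .npz file.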
Owner
Kent Sommer
Software Engineer @ Toyota Research Institute (SF Bay Area)