PyTorch implementation of Value Iteration Networks (NIPS 2016 best paper)

Overview

VIN: Value Iteration Networks

Architecture of Value Iteration Network

A quick thank you

A few others have released amazing related work which helped inspire and improve my own implementation. It goes without saying that this release would not be nearly as good if it were not for all of the following:

Why another VIN implementation?

  1. The PyTorch VIN model in this repository is, in my opinion, more readable and closer to the original Theano implementation than others I have found (both TensorFlow and PyTorch).
  2. This is not simply an implementation of the VIN model in PyTorch; it is also a full Python implementation of the gridworld environments used in the original MATLAB implementation.
  3. It provides a more extensible research base for others to build on without needing to get past a possible MATLAB paywall.

Installation

This repository requires the following packages:

Use pip to install the necessary dependencies:

pip install -U -r requirements.txt 

Note that PyTorch may need to be installed separately from the other requirements; refer to http://pytorch.org/ for installation instructions specific to your needs.

How to train

8x8 gridworld

python train.py --datafile dataset/gridworld_8x8.npz --imsize 8 --lr 0.005 --epochs 30 --k 10 --batch_size 128

16x16 gridworld

python train.py --datafile dataset/gridworld_16x16.npz --imsize 16 --lr 0.002 --epochs 30 --k 20 --batch_size 128

28x28 gridworld

python train.py --datafile dataset/gridworld_28x28.npz --imsize 28 --lr 0.002 --epochs 30 --k 36 --batch_size 128

Flags:

  • datafile: The path to the data file.
  • imsize: The size of input images. One of: [8, 16, 28]
  • lr: Learning rate for the RMSProp optimizer. Recommended: [0.01, 0.005, 0.002, 0.001]
  • epochs: Number of epochs to train. Default: 30
  • k: Number of Value Iterations. Recommended: [10 for 8x8, 20 for 16x16, 36 for 28x28]
  • l_i: Number of channels in the input layer. Default: 2, i.e. the obstacle image and the goal image.
  • l_h: Number of channels in the first convolutional layer. Default: 150, as described in the paper.
  • l_q: Number of channels in the q layer (~actions) in the VI module. Default: 10, as described in the paper.
  • batch_size: Batch size. Default: 128
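
The k flag sets how many value-iteration steps the VI module unrolls. As a hedged, framework-free sketch (plain tabular value iteration, not this repository's model code), the block below assumes 8 moves (4 straight, 4 diagonal) at unit cost each; the names `value_iteration` and `MOVES` are hypothetical. It shows that after k synchronous Bellman backups, value information has travelled at most k moves from the goal, which is one plausible way to read why the recommended k grows with imsize (paths around obstacles can be longer than the grid is wide).

```python
import math

# Hypothetical action set: 4 straight and 4 diagonal moves, each assumed to cost 1.
MOVES = [(-1, 0), (1, 0), (0, 1), (0, -1), (-1, 1), (-1, -1), (1, 1), (1, -1)]


def value_iteration(obstacles, goal, k):
    """Run k synchronous Bellman backups on a grid.

    obstacles: list of lists, 1 = blocked, 0 = free.
    goal: (row, col) of the goal cell.
    Returns V with V[r][c] = -(number of moves to the goal) if the goal is
    reachable within k moves, otherwise -inf.
    """
    rows, cols = len(obstacles), len(obstacles[0])
    V = [[-math.inf] * cols for _ in range(rows)]
    V[goal[0]][goal[1]] = 0.0
    for _ in range(k):
        new_V = [row[:] for row in V]  # synchronous update: read old V, write new_V
        for r in range(rows):
            for c in range(cols):
                if obstacles[r][c] or (r, c) == tuple(goal):
                    continue  # obstacles stay at -inf, the goal stays at 0
                best = -math.inf
                for dr, dc in MOVES:
                    nr, nc = r + dr, c + dc
                    if 0 <= nr < rows and 0 <= nc < cols and not obstacles[nr][nc]:
                        # one-step cost plus the current value of the successor cell
                        best = max(best, -1.0 + V[nr][nc])
                new_V[r][c] = best
        V = new_V
    return V


grid = [[0] * 8 for _ in range(8)]  # empty 8x8 gridworld
V10 = value_iteration(grid, goal=(0, 0), k=10)
V3 = value_iteration(grid, goal=(0, 0), k=3)
print(V10[7][7], V3[3][3], V3[7][7])  # -7.0 -3.0 -inf
```

With k=3, the far corner of the 8x8 grid has not yet received any value from the goal, so a greedy policy there would have no signal to follow.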

How to test / visualize paths (requires training first)

8x8 gridworld

python test.py --weights trained/vin_8x8.pth --imsize 8 --k 10

16x16 gridworld

python test.py --weights trained/vin_16x16.pth --imsize 16 --k 20

28x28 gridworld

python test.py --weights trained/vin_28x28.pth --imsize 28 --k 36

To visualize the optimal and predicted paths, simply pass:

--plot

Flags:

  • weights: Path to the trained weights.
  • imsize: The size of input images. One of: [8, 16, 28]
  • plot: If supplied, the optimal and predicted paths will be plotted.
  • k: Number of Value Iterations. Recommended: [10 for 8x8, 20 for 16x16, 36 for 28x28]
  • l_i: Number of channels in the input layer. Default: 2, i.e. the obstacle image and the goal image.
  • l_h: Number of channels in the first convolutional layer. Default: 150, as described in the paper.
  • l_q: Number of channels in the q layer (~actions) in the VI module. Default: 10, as described in the paper.

Results

Figure: two sample paths (Sample One, Sample Two) for each of the 8x8, 16x16, and 28x28 gridworlds.

Datasets

Each data sample consists of an obstacle image and a goal image, followed by the (x, y) coordinates of the current state in the gridworld.
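
To make the sample layout concrete, the sketch below assembles one sample as nested Python lists: the obstacle and goal images stacked as the two input channels (matching the default l_i = 2), followed by the state coordinates. The function name `make_sample`, the goal encoded as a single 1, and x being used as the row index are all assumptions for illustration; the actual encoding in make_training_data.py may differ.

```python
def make_sample(obstacles, goal, state):
    """Build one hypothetical sample: a 2-channel image plus the (x, y) state.

    obstacles: imsize x imsize list of lists, 1 = obstacle, 0 = free.
    goal, state: (x, y) cell coordinates; x is treated as the row index here.
    """
    imsize = len(obstacles)
    # Goal channel: zeros everywhere except the goal cell (value 1 is an assumption).
    goal_image = [[0] * imsize for _ in range(imsize)]
    gx, gy = goal
    goal_image[gx][gy] = 1
    image = [obstacles, goal_image]  # shape (l_i = 2, imsize, imsize)
    x, y = state
    return image, x, y


obstacles = [[0] * 8 for _ in range(8)]
obstacles[3][4] = 1
image, x, y = make_sample(obstacles, goal=(7, 7), state=(0, 2))
print(len(image), image[1][7][7], (x, y))  # 2 1 (0, 2)
```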

Dataset size (samples) | 8x8 | 16x16 | 28x28
Train set | 81337 | 456309 | 1529584
Test set | 13846 | 77203 | 251755

The datasets (8x8, 16x16, and 28x28) included in this repository can be reproduced with the dataset/make_training_data.py script. Note that this script is not optimized: it runs rather slowly and also uses a lot of memory :D
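
As a rough, hedged back-of-the-envelope check on memory use, the sketch below estimates how much RAM the training images alone occupy once cast to float32 (the cast that fails in the MemoryError report under Comments). It assumes each sample stores one 2-channel imsize x imsize image (obstacle and goal) and ignores the S1/S2/label arrays; the function name `float32_image_bytes` is hypothetical.

```python
def float32_image_bytes(n_samples, imsize, channels=2):
    """Bytes for n_samples images of shape (channels, imsize, imsize) stored as float32."""
    bytes_per_float32 = 4
    return n_samples * channels * imsize * imsize * bytes_per_float32


# Sample counts taken from the "Train set" row of the dataset table.
train_sizes = {8: 81337, 16: 456309, 28: 1529584}
for imsize, n in train_sizes.items():
    print(f"{imsize}x{imsize}: {float32_image_bytes(n, imsize) / 1e9:.2f} GB")
# 8x8: 0.04 GB
# 16x16: 0.93 GB
# 28x28: 9.59 GB
```

Because NumPy's astype returns a new array by default, the original array and its float32 copy briefly coexist, so peak usage during loading can be higher than these figures.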

Performance: Success Rate

This is the success rate from rollouts of the learned policy in the environment (taken over 5000 randomly generated domains).

Success rate | 8x8 | 16x16 | 28x28
PyTorch | 99.69% | 96.99% | 91.07%

Performance: Test Accuracy

NOTE: This is the accuracy on the test set. It differs from the table in the paper, which reports the success rate from rollouts of the learned policy in the environment.

Test accuracy | 8x8 | 16x16 | 28x28
PyTorch | 99.83% | 94.84% | 88.54%
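
The two metrics can diverge in either direction. In the toy sketch below (a 1-D corridor with hypothetical names, not the repository's gridworld or test.py), a policy that matches the label in 3 of 4 states still fails 3 of 4 rollouts, because its one wrong action sends the agent back and forth forever. The opposite can also happen: a prediction that differs from the label but is an equally short move lowers test accuracy without causing a rollout failure.

```python
# Toy corridor: states 0..4, goal at state 4; actions move one cell left or right.
STEP = {"left": -1, "right": +1}


def per_state_accuracy(predicted, labels):
    """Fraction of labelled states where the predicted action equals the label."""
    return sum(predicted[s] == labels[s] for s in labels) / len(labels)


def rollout(predicted, start, goal, n_states, max_steps):
    """Follow the predicted actions from `start`; succeed if `goal` is reached within max_steps."""
    s = start
    for _ in range(max_steps):
        if s == goal:
            return True
        s = min(max(s + STEP[predicted[s]], 0), n_states - 1)  # clamp to the corridor
    return s == goal


def success_rate(predicted, starts, goal, n_states, max_steps):
    """Fraction of rollouts from `starts` that reach the goal."""
    return sum(rollout(predicted, s0, goal, n_states, max_steps) for s0 in starts) / len(starts)


labels = {0: "right", 1: "right", 2: "right", 3: "right"}
predicted = {0: "right", 1: "right", 2: "left", 3: "right"}  # one wrong action, at state 2
print(per_state_accuracy(predicted, labels))                                     # 0.75
print(success_rate(predicted, [0, 1, 2, 3], goal=4, n_states=5, max_steps=20))  # 0.25
```
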
Comments
  • testing accuracy fairly low

    I just tried to follow the instructions in the repo and tested the trained models, but got fairly low accuracy. I'm using PyTorch 0.1.12_1. Is there anything I should pay attention to?

    opened by xinleipan 10
  • Prebuilt Dataset Generation

    Hello,

    I was wondering how you generated the prebuilt datasets that are downloaded when running download_weights_and_datasets.sh, i.e. what were the max_obs and max_obs_size parameters?

    Did you follow this file in the original repo? https://github.com/avivt/VIN/blob/master/scripts/make_data_gridworld_nips.m

    Thanks, Emilio

    opened by eparisotto 5
  • the rollout accuracy in test script is lower than the test accuracy in train script.

    Hello!

    I have a small question. Does the rollout accuracy indicate the success rate? If so, why is it lower than the prediction accuracy? In Aviv's implementation, the success rate on the 8x8 gridworld was as high as 99.6%. Why is the success rate in your experiment relatively low?

    Thanks!

    opened by albzni 4
  • RUN ERROR

    When I run 'python train.py --datafile dataset/gridworld_8x8.npz --imsize 8 --lr 0.005 --epochs 30 --k 10 --batch_size 128', it works, but when I then run 'python train.py --datafile dataset/gridworld_16x16.npz --imsize 16 --lr 0.002 --epochs 30 --k 20 --batch_size 128', the following error occurs:

    ~/pytorch-value-iteration-networks$ python train.py --datafile dataset/gridworld_16x16.npz --imsize 16 --lr 0.002 --epochs 10 --k 20 --batch_size 128
    Traceback (most recent call last):
      File "train.py", line 135, in <module>
        config.datafile, imsize=config.imsize, train=True, transform=transform)
      File "/home/ni/pytorch-value-iteration-networks/dataset/dataset.py", line 22, in __init__
        self._process(file, self.train)
      File "/home/ni/pytorch-value-iteration-networks/dataset/dataset.py", line 58, in _process
        images = images.astype(np.float32)
    MemoryError

    opened by N-Kingsley 3
  • Problem of running the test script

    Hello,

    I downloaded the data with the .sh download script you provided, and I also got an nps weights file after training. When I ran the testing command, I got the following error:

    Traceback (most recent call last):
      File "/home/research/DL/VIN/pytorch-value-iteration-networks/test.py", line 158, in <module>
        main(config)
      File "/home/research/DL/VIN/pytorch-value-iteration-networks/test.py", line 85, in main
        _, predictions = vin(X_in, S1_in, S2_in, config)
      File "/usr/local/lib/python2.7/dist-packages/torch/nn/modules/module.py", line 357, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/research/DL/VIN/pytorch-value-iteration-networks/model.py", line 64, in forward
        return logits, self.sm(logits)
      File "/usr/local/lib/python2.7/dist-packages/torch/nn/modules/module.py", line 352, in __call__
        for hook in self._forward_pre_hooks.values():
      File "/usr/local/lib/python2.7/dist-packages/torch/nn/modules/module.py", line 398, in __getattr__
        type(self).__name__, name))
    AttributeError: 'Softmax' object has no attribute '_forward_pre_hooks'

    Thanks for helping!

    opened by YantianZha 3
  • Improved readability of the VIN model, in addition to minor changes

    My main modification is in the forward method of the model, where you extract q_out from the q values: it no longer repeats q = F.conv2d(...) in two places. I also made minor improvements, such as adding argparse to the dataset creation script and changing .cuda() to .to(device) in test.py.

    opened by shuishida 2
  • Inconsistent tensor sizes when starting training

    Hey there. I'm trying to run

    python train.py --datafile dataset/gridworld_8x8.npz --imsize 8 --lr 0.005 --epochs 30 --k 10 --batch_size 128
    

    But I get the following error

    Number of Train Samples: 103926
    Number of Test Samples: 17434
         Epoch | Train Loss | Train Error | Epoch Time
    Traceback (most recent call last):
      File "train.py", line 147, in <module>
        train(net, trainloader, config, criterion, optimizer, use_GPU)
      File "train.py", line 40, in train
        outputs, predictions = net(X, S1, S2, config)
      File "/home/j1k1000o/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 224, in __call__
        result = self.forward(*input, **kwargs)
      File "/media/user_home2/j1k1000o/j1k/VINs/pytorch-value-iteration-networks/model.py", line 44, in forward
        q = F.conv2d(torch.cat([r, v], 1), 
      File "/home/j1k1000o/anaconda3/lib/python3.6/site-packages/torch/autograd/variable.py", line 897, in cat
        return Concat.apply(dim, *iterable)
      File "/home/j1k1000o/anaconda3/lib/python3.6/site-packages/torch/autograd/_functions/tensor.py", line 317, in forward
        return torch.cat(inputs, dim)
    RuntimeError: inconsistent tensor sizes at /opt/conda/conda-bld/pytorch_1502009910772/work/torch/lib/THC/generic/THCTensorMath.cu:141
    

    I've executed

    ./download_weights_and_datasets.sh
    

    as well as

    python ./dataset/make_training_data.py
    

    And I'm running it on Ubuntu 16.04 with Python 3.6 and all the requirements installed.

    Can you help me out?

    opened by juancprzs 2
  • Don't understand VIN last step

        slice_s1 = S1.long().expand(config.imsize, 1, config.l_q, q.size(0))
        slice_s1 = slice_s1.permute(3, 2, 1, 0)
        q_out = q.gather(2, slice_s1).squeeze(2)
    

    What do these three lines do?

    opened by QiXuanWang 1
  • KeyError: 'arr_1 is not a file in the archive'

    python3 train.py --datafile dataset/gridworld_8x8.npz --imsize 8 --lr 0.005 --epochs 30 --k 10 --batch_size 128
    Traceback (most recent call last):
      File "train.py", line 135, in <module>
        config.datafile, imsize=config.imsize, train=True, transform=transform)
      File "/home/user/pytorch/tutorials/valueiterationnetworks/pytorch-value-iteration-networks/dataset/dataset.py", line 22, in __init__
        self._process(file, self.train)
      File "/home/user/pytorch/tutorials/valueiterationnetworks/pytorch-value-iteration-networks/dataset/dataset.py", line 49, in _process
        S1 = f['arr_1']
      File "/home/user/miniconda3/lib/python3.6/site-packages/numpy/lib/npyio.py", line 255, in __getitem__
        raise KeyError("%s is not a file in the archive" % key)
    KeyError: 'arr_1 is not a file in the archive'

    I got this error, could you please

    opened by derelearnro 1
  • Problem of running dataset/make_training_data.py script

    Hi

    When I tried to run the make_training_data.py script to generate the gridworld.npz file, I got the following error:

    FileNotFoundError: [Errno 2] No such file or directory: 'dataset/gridworld_28x28.npz'
    

    And I found that line 101 should be modified as follows:

    save_path = "gridworld_{0}x{1}".format(dom_size[0], dom_size[1])
    
    opened by ruqing00 0
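
To illustrate what the three lines quoted in the "Don't understand VIN last step" issue compute, here is a pure-Python mirror of torch.gather along dimension 2, with nested lists standing in for tensors (the helper names `gather_dim2` and `select_row` are hypothetical). Assuming S1 holds one row index per batch element and q has shape (batch, l_q, imsize, imsize), the expand/permute builds an index tensor of shape (batch, l_q, 1, imsize) filled with S1[b], so the gather returns row S1[b] of q[b, a] for every action channel a, and squeeze(2) drops the singleton dimension. Picking column S2[b] from that row (an assumed follow-up step, since the issue quotes only these three lines) gives one q-value per action at the agent's current state.

```python
def gather_dim2(q, index):
    """Pure-Python mirror of torch.gather(q, 2, index) for 4-D nested lists.

    out[b][a][i][j] = q[b][a][index[b][a][i][j]][j]; the output has the shape of `index`.
    """
    return [[[[q[b][a][index[b][a][i][j]][j]
               for j in range(len(index[b][a][i]))]
              for i in range(len(index[b][a]))]
             for a in range(len(index[b]))]
            for b in range(len(index))]


def select_row(q, S1):
    """Mimic the quoted expand/permute/gather/squeeze on q of shape (batch, l_q, imsize, imsize)."""
    batch, l_q, imsize = len(q), len(q[0]), len(q[0][0][0])
    # expand(imsize, 1, l_q, batch).permute(3, 2, 1, 0): shape (batch, l_q, 1, imsize), every entry S1[b]
    index = [[[[S1[b]] * imsize] for _ in range(l_q)] for b in range(batch)]
    gathered = gather_dim2(q, index)  # shape (batch, l_q, 1, imsize)
    return [[gathered[b][a][0] for a in range(l_q)] for b in range(batch)]  # squeeze(2)


# Toy q-values: q[b][a][r][c] = 1000*b + 100*a + 10*r + c, with batch=2, l_q=3, imsize=4.
q = [[[[1000 * b + 100 * a + 10 * r + c for c in range(4)] for r in range(4)]
      for a in range(3)] for b in range(2)]
S1, S2 = [1, 3], [2, 0]
rows = select_row(q, S1)
print(rows[1][0])  # [1030, 1031, 1032, 1033], i.e. row S1[1] = 3 of q[1][0]
# Assumed follow-up step: pick column S2[b] to get one q-value per action at the current state.
q_at_state = [[rows[b][a][S2[b]] for a in range(3)] for b in range(2)]
print(q_at_state)  # [[12, 112, 212], [1030, 1130, 1230]]
```
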
Owner
Kent Sommer
Software Engineer @ Toyota Research Institute (SF Bay Area)