Meta-Learning Sparse Implicit Neural Representations (NeurIPS 2021)

Meta-SparseINR

Official PyTorch implementation of "Meta-learning Sparse Implicit Neural Representations" (NeurIPS 2021) by Jaeho Lee*, Jihoon Tack*, Namhoon Lee, and Jinwoo Shin.

TL;DR: We develop a scalable method to learn sparse neural representations for a large set of signals.

Illustrations of (a) an implicit neural representation, (b) the standard pruning algorithm that prunes and retrains the model for each signal considered, and (c) the proposed Meta-SparseINR procedure to find a sparse initial INR, which can be trained further to fit each signal.

1. Requirements

conda create -n inrprune python=3.7
conda activate inrprune

conda install pytorch torchvision cudatoolkit=11.1 -c pytorch -c nvidia

pip install torchmeta
pip install imageio einops tensorboardX

Datasets

  • Download the Imagenette and SDF files from the following page:
  • Place the datasets in the /data folder.

2. Training

Training option

The training option is as follows:

  • <DATASET>: {celeba,sdf,imagenette}

Meta-SparseINR (ours)

# Train dense model first
python main.py --exp meta_baseline --epoch 150000 --data <DATASET>

# Iterative pruning (magnitude pruning)
python main.py --exp metaprune --epoch 30000 --pruner MP --amount 0.2 --data <DATASET>
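
The flags `--pruner MP --amount 0.2` suggest that each pruning round removes 20% of the surviving weights with the smallest magnitude. The sketch below illustrates that idea in plain Python under stated assumptions: the helper `magnitude_prune` is hypothetical, pruning is done globally over a flat list of weights, and the retraining between rounds is omitted. The repository's actual pruner operates on model tensors and may differ in these details.

```python
import random


def magnitude_prune(weights, mask, amount=0.2):
    """Return a new mask that drops the `amount` fraction of surviving
    weights with the smallest absolute value (hypothetical helper)."""
    alive = [i for i, keep in enumerate(mask) if keep]
    n_drop = int(round(amount * len(alive)))
    # Sort surviving indices by |w| and drop the smallest ones.
    to_drop = set(sorted(alive, key=lambda i: abs(weights[i]))[:n_drop])
    return [keep and i not in to_drop for i, keep in enumerate(mask)]


random.seed(0)
weights = [random.gauss(0.0, 1.0) for _ in range(1000)]
mask = [True] * len(weights)

densities = []
for round_idx in range(5):
    # In the real pipeline the model would be trained again here
    # (--epoch 30000 in the command above) before the next round.
    mask = magnitude_prune(weights, mask, amount=0.2)
    densities.append(sum(mask) / len(mask))

print(densities)
```

After k rounds the fraction of remaining weights is roughly 0.8^k, so repeated rounds produce a sequence of progressively sparser models.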

Random Pruning

# Train dense model first
python main.py --exp meta_baseline --epoch 150000 --data <DATASET>

# Iterative pruning (random pruning)
python main.py --exp metaprune --epoch 30000 --pruner RP --amount 0.2 --data <DATASET>
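
For contrast with magnitude pruning, the `RP` baseline presumably removes the same fraction of surviving weights but chooses them uniformly at random. A minimal sketch, assuming a flat boolean mask and a hypothetical helper `random_prune` (not the repository's API):

```python
import random


def random_prune(mask, amount=0.2, rng=random):
    """Return a new mask that drops a uniformly random `amount` fraction
    of the surviving weights (hypothetical helper)."""
    alive = [i for i, keep in enumerate(mask) if keep]
    n_drop = int(round(amount * len(alive)))
    to_drop = set(rng.sample(alive, n_drop))
    return [keep and i not in to_drop for i, keep in enumerate(mask)]


rng = random.Random(0)  # fixed seed so the mask is reproducible
mask = [True] * 1000
for _ in range(3):
    mask = random_prune(mask, amount=0.2, rng=rng)

density = sum(mask) / len(mask)
print(density)
```

The resulting sparsity level matches magnitude pruning with the same `amount`; only the choice of which weights survive differs.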

Dense-Narrow

# Train dense model with a given width

# Shell script style
widthlist="230 206 184 164 148 132 118 106 94 84 76 68 60 54 48 44 38 34 32 28"
for width in $widthlist
do
    python main.py --exp meta_baseline --epoch 150000 --data <DATASET> --width $width --id width_$width
done
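
The same sweep can be driven from Python, which makes it easy to stop at the first failing run. This is only a sketch: the helpers `build_commands` and `run_all` are hypothetical, the flags are copied from the shell loop above, and `celeba` stands in for `<DATASET>`. It runs in dry-run mode by default and only prints the commands.

```python
import subprocess
import sys

WIDTHS = [230, 206, 184, 164, 148, 132, 118, 106, 94, 84,
          76, 68, 60, 54, 48, 44, 38, 34, 32, 28]


def build_commands(dataset, widths=WIDTHS, epoch=150000):
    """Build one `main.py` argument list per width (hypothetical helper;
    flags are taken from the shell loop above)."""
    return [
        [sys.executable, "main.py", "--exp", "meta_baseline",
         "--epoch", str(epoch), "--data", dataset,
         "--width", str(w), "--id", "width_{}".format(w)]
        for w in widths
    ]


def run_all(commands, dry_run=True):
    """Print each command and, unless dry_run is set, execute it."""
    for cmd in commands:
        print(" ".join(cmd))
        if not dry_run:
            # check=True raises on a non-zero exit code, so the sweep
            # stops at the first failing run.
            subprocess.run(cmd, check=True)


commands = build_commands("celeba")
run_all(commands, dry_run=True)
```

Passing `dry_run=False` launches the runs sequentially from the repository root.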

3. Evaluation

Evaluation option

The evaluation options are as follows:

  • <DATASET>: {celeba,sdf,imagenette}
  • <OPT_TYPE>: {default,two_step_sgd}, where default denotes the Adam optimizer with 100 steps.
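
To make the two `<OPT_TYPE>` settings concrete, the sketch below adapts a toy model to a single signal starting from a fixed initialization, once with 100 Adam steps (`default`) and once with two plain SGD steps, which is what the name `two_step_sgd` suggests. Everything here is an assumption for illustration: the linear model stands in for an INR, the initialization and learning rates are made up, and the helper names are hypothetical rather than the repository's code.

```python
import math


def loss_and_grad(theta, xs, ys):
    """Mean squared error of the toy model y = a * x + b and its gradient."""
    a, b = theta
    n = len(xs)
    errs = [a * x + b - y for x, y in zip(xs, ys)]
    loss = sum(e * e for e in errs) / n
    grad = [2 * sum(e * x for e, x in zip(errs, xs)) / n,
            2 * sum(errs) / n]
    return loss, grad


def adapt_sgd(theta, xs, ys, steps=2, lr=0.1):
    """Plain gradient descent for a fixed number of steps (assumed lr)."""
    theta = list(theta)
    for _ in range(steps):
        _, g = loss_and_grad(theta, xs, ys)
        theta = [p - lr * gi for p, gi in zip(theta, g)]
    return theta


def adapt_adam(theta, xs, ys, steps=100, lr=0.01,
               b1=0.9, b2=0.999, eps=1e-8):
    """Adam with bias-corrected moment estimates (assumed lr)."""
    theta = list(theta)
    m = [0.0] * len(theta)
    v = [0.0] * len(theta)
    for t in range(1, steps + 1):
        _, g = loss_and_grad(theta, xs, ys)
        m = [b1 * mi + (1 - b1) * gi for mi, gi in zip(m, g)]
        v = [b2 * vi + (1 - b2) * gi * gi for vi, gi in zip(v, g)]
        m_hat = [mi / (1 - b1 ** t) for mi in m]
        v_hat = [vi / (1 - b2 ** t) for vi in v]
        theta = [p - lr * mh / (math.sqrt(vh) + eps)
                 for p, mh, vh in zip(theta, m_hat, v_hat)]
    return theta


# Toy "signal": samples of y = 0.5 * x + 0.2 on a 1-D coordinate grid.
xs = [i / 10 for i in range(-10, 11)]
ys = [0.5 * x + 0.2 for x in xs]
init = [0.4, 0.1]  # made-up stand-in for a meta-learned initialization

loss_init, _ = loss_and_grad(init, xs, ys)
loss_sgd, _ = loss_and_grad(adapt_sgd(init, xs, ys), xs, ys)
loss_adam, _ = loss_and_grad(adapt_adam(init, xs, ys), xs, ys)
print(loss_init, loss_sgd, loss_adam)
```

Both settings start from the same initialization and differ only in the optimizer and the number of update steps.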

The commands below assume that all checkpoints have already been trained.

Meta-SparseINR (ours)

python eval.py --exp prune --pruner MP --data <DATASET> --opt_type <OPT_TYPE>

Baselines

# Random pruning
python eval.py --exp prune --pruner RP --data <DATASET> --opt_type <OPT_TYPE>

# Dense-Narrow
python eval.py --exp dense_narrow --data <DATASET> --opt_type <OPT_TYPE>

# MAML + One-Shot
python eval.py --exp one_shot --data <DATASET> --opt_type default

# MAML + IMP
python eval.py --exp imp --data <DATASET> --opt_type default

# Scratch
python eval.py --exp scratch --data <DATASET> --opt_type <OPT_TYPE>

4. Experimental Results

Citation

@inproceedings{lee2021meta,
  title={Meta-learning Sparse Implicit Neural Representations},
  author={Jaeho Lee and Jihoon Tack and Namhoon Lee and Jinwoo Shin},
  booktitle={Advances in Neural Information Processing Systems},
  year={2021}
}

Reference
