Proximal Backpropagation - a neural network training algorithm that takes implicit instead of explicit gradient steps

Proximal Backpropagation

Proximal Backpropagation (ProxProp) is a neural network training algorithm that takes implicit instead of explicit gradient steps to update the network parameters. We have analyzed this algorithm in our ICLR 2018 paper:

Proximal Backpropagation (Thomas Frerix, Thomas Möllenhoff, Michael Moeller, Daniel Cremers; ICLR 2018) [https://arxiv.org/abs/1706.04638]

tl;dr

  • We provide a PyTorch implementation of ProxProp for Python 3 and PyTorch 1.0.1.
  • The results of our paper can be reproduced by executing the script paper_experiments.sh.
  • ProxProp is implemented as a torch.nn.Module (a 'layer') and can be combined with any other layer and first-order optimizer. While a ProxPropConv2d and a ProxPropLinear layer already exist, you can generate a ProxProp layer for your favorite linear layer with one line of code.

Installation

  1. Make sure you have a running Python 3 ecosystem (tested with Python 3.7). We recommend a conda installation, as this is also the recommended way to get the latest PyTorch running. For this README and for the scripts, we assume that you have conda running with Python 3.7.
  2. Clone this repository and switch to the directory.
  3. Install the dependencies via conda install --file conda_requirements.txt and pip install -r pip_requirements.txt.
  4. Install PyTorch with magma support. We have tested our code with PyTorch 1.0.1 and CUDA 10.0. You can install this setup via
    conda install -c pytorch magma-cuda100
    conda install pytorch torchvision cudatoolkit=10.0 -c pytorch
    
  5. (optional, but necessary to reproduce paper experiments) Download the CIFAR-10 dataset by executing get_data.sh
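
As an optional sanity check (not part of the repository), you can verify from Python that the installed PyTorch and CUDA setup is picked up:

import torch

print(torch.__version__)          # tested setup: 1.0.1
print(torch.cuda.is_available())  # should be True with CUDA 10.0 installed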

Training neural networks with ProxProp

ProxProp is implemented as a custom linear layer (torch.nn.Module) with its own backward pass that takes implicit gradient steps on the network parameters. With this design choice, it can be combined with any other layer, for which explicit gradient steps are taken. Furthermore, the resulting update direction can be used with any first-order optimizer that expects a suitable update direction in parameter space. In our paper we prove that ProxProp generates a descent direction and show experiments with Nesterov SGD and Adam.
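
As a minimal sketch (not taken from the repository), a ProxProp layer can be dropped into a standard PyTorch model and trained with Adam; the architecture, data, and hyperparameters below are placeholders:

import torch
import torch.nn as nn
from ProxProp import ProxPropLinear

# Placeholder model: the ProxProp layer takes implicit steps,
# the remaining layers take explicit gradient steps.
model = nn.Sequential(
    ProxPropLinear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

x, y = torch.randn(32, 784), torch.randint(0, 10, (32,))  # dummy batch
optimizer.zero_grad()
loss = criterion(model(x), y)
loss.backward()     # ProxProp layers compute their update direction here
optimizer.step()    # any first-order optimizer can consume that direction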

You can use our pre-defined layers ProxPropConv2d and ProxPropLinear, corresponding to nn.Conv2d and nn.Linear, by importing

from ProxProp import ProxPropConv2d, ProxPropLinear

Besides the usual layer parameters, as detailed in the PyTorch docs, you can provide:

  • tau_prox: step size for a proximal step; default is tau_prox=1
  • optimization_mode: can be one of 'prox_exact', 'prox_cg{N}', or 'gradient', for an exact proximal step, an approximate proximal step with N conjugate gradient steps, or an explicit gradient step, respectively; default is optimization_mode='prox_cg1'. The 'gradient' mode allows a fair comparison with SGD, as it incurs the same overhead as the other modes from using a generic implementation built on the provided PyTorch API (see the example below).
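
As an illustration, a ProxProp layer might be constructed as follows; this is a sketch and the specific values are placeholders, not recommendations:

from ProxProp import ProxPropConv2d

# Illustrative values only: a convolution taking approximate proximal steps
# with 3 conjugate gradient steps and step size tau_prox = 1.
conv = ProxPropConv2d(3, 64, 3, tau_prox=1, optimization_mode='prox_cg3')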

If you want to use ProxProp to optimize your favorite linear layer, you can generate the respective module with one line of code. As an example, for the Conv3d layer:

import torch
from ProxProp import proxprop_module_generator
ProxPropConv3d = proxprop_module_generator(torch.nn.Conv3d)

This gives you a default implementation for the approximate conjugate gradient solver, which treats all parameters as a stacked vector. If you want to use the exact solver or want to use the conjugate gradient solver more efficiently, you have to provide the respective reshaping methods to proxprop_module_generator, as this requires specific knowledge of the layer's structure and cannot be implemented generically. As a template, take a look at the ProxProp.py file, where we have done this for the ProxPropLinear layer.
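
A generated module can then be used like any other layer; a short sketch with arbitrary channel sizes and a dummy input:

import torch
import torch.nn as nn
from ProxProp import proxprop_module_generator

ProxPropConv3d = proxprop_module_generator(torch.nn.Conv3d)

# Arbitrary sizes for illustration: in_channels=1, out_channels=8, kernel_size=3.
net = nn.Sequential(ProxPropConv3d(1, 8, 3), nn.ReLU())
out = net(torch.randn(2, 1, 16, 16, 16))  # dummy 3D input batch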

By reusing the forward/backward implementations of existing PyTorch modules, ProxProp becomes readily accessible. However, this comes with an overhead from constructing the backward pass generically through the PyTorch API; we have intentionally favored generality over speed.

Reproduce paper experiments

To reproduce the paper experiments, execute the script paper_experiments.sh. This will run our paper's experiments, store the results in the directory paper_experiments/, and subsequently compile them into the file paper_plots.pdf. On an NVIDIA Titan X GPU, executing the script takes roughly 3 hours.

Acknowledgement

We want to thank Soumith Chintala for helping us track down a mysterious bug and the whole PyTorch dev team for their continued development effort and great support to the community.

Publication

If you use ProxProp, please acknowledge our paper by citing

@article{Frerix-et-al-18,
    title = {Proximal Backpropagation},
    author={Thomas Frerix and Thomas Möllenhoff and Michael Moeller and Daniel Cremers},
    journal={International Conference on Learning Representations},
    year={2018},
    url = {https://arxiv.org/abs/1706.04638}
}