InvTorch: memory-efficient models with invertible functions

Related tags

Deep Learninginvtorch
Overview

InvTorch: Memory-Efficient Invertible Functions

This module extends the functionality of torch.utils.checkpoint.checkpoint to work with invertible functions. So, not only the intermediate activations will be released from memory. The input tensors get deallocated and recomputed later using the inverse function only in the backward pass. This is useful in extreme situations where more compute is traded with memory. However, there are few caveats to consider which are detailed here.

Installation

InvTorch has minimal dependencies. It only requires PyTorch version 1.10.0 or later.

conda install pytorch==1.10.0 torchvision torchaudio cudatoolkit=11.3 -c pytorch
pip install invtorch

Basic Usage

The main module that we are interested in is InvertibleModule which inherits from torch.nn.Module. Subclass it to implement your own invertible code.

import torch
from torch import nn
from invtorch import InvertibleModule


class InvertibleLinear(InvertibleModule):
    def __init__(self, in_features, out_features):
        super().__init__(invertible=True, checkpoint=True)
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.bias = nn.Parameter(torch.randn(out_features))

    def function(self, inputs):
        outputs = inputs @ self.weight.T + self.bias
        requires_grad = self.do_require_grad(inputs, self.weight, self.bias)
        return outputs.requires_grad_(requires_grad)

    def inverse(self, outputs):
        return (outputs - self.bias) @ self.weight.T.pinverse()

Structure

You can immediately notice few differences to the regular PyTorch module here. There is no longer a need to define forward(). Instead, it is replaced with function(*inputs). Additionally, it is necessary to define its inverse function as inverse(*outputs). Both methods can only take one or more positional arguments and return a torch.Tensor or a tuple of outputs which can have anything including tensors.

Requires Gradient

function() must manually call .requires_grad_(True/False) on all output tensors. The forward pass is run in no_grad mode and there is no way to detect which output need gradients without tracing. It is possible to infer this from requires_grad values of the inputs and self.parameters(). The above code uses do_require_grad() which returns True if any input did require gradient.

Example

Now, this model is ready to be instantiated and used directly.

x = torch.randn(10, 3)
model = InvertibleLinear(3, 5)
print('Is invertible:', model.check_inverse(x))

y = model(x)
print('Output requires_grad:', y.requires_grad)
print('Input was freed:', x.storage().size() == 0)

y.backward(torch.randn_like(y))
print('Input was restored:', x.storage().size() != 0)

Checkpoint and Invertible Modes

InvertibleModule has two flags which control the mode of operation; checkpoint and invertible. If checkpoint was set to False, or when working in no_grad mode, or no input or parameter has requires_grad set to True, it acts exactly as a normal PyTorch module. Otherwise, the model is either invertible or an ordinary checkpoint depending on whether invertible is set to True or False, respectively. Those, flags can be changed at any time during operation without any repercussions.

Limitations

Under the hood, InvertibleModule uses invertible_checkpoint(); a low-level implementation which allows it to function. There are few considerations to keep in mind when working with invertible checkpoints and non-materialized tensors. Please, refer to the documentation in the code for more details.

Overriding forward()

Although forward() is now doing important things to ensure the validity of the results when calling invertible_checkpoint(), it can still be overridden. The main reason of doing so is to provide a more user-friendly interface; function signature and output format. For example, function() could return extra outputs that are not needed in the module outputs but are essential for correctly computing the inverse(). In such case, define forward() to wrap outputs = super().forward(*inputs) more cleanly.

TODOs

Here are few feature ideas that could be implemented to enrich the utility of this package:

  • Add more basic operations and modules
  • Add coupling and interleave -based invertible operations
  • Add more checks to help the user in debugging more features
  • Allow picking some inputs to not be freed in invertible mode
  • Context-manager to temporarily change the mode of operation
  • Implement dynamic discovery for outputs that requires_grad
  • Develop an automatic mode optimization for a network for various objectives
You might also like...
A memory-efficient implementation of DenseNets

efficient_densenet_pytorch A PyTorch =1.0 implementation of DenseNets, optimized to save GPU memory. Recent updates Now works on PyTorch 1.0! It uses

Official and maintained implementation of the paper
Official and maintained implementation of the paper "OSS-Net: Memory Efficient High Resolution Semantic Segmentation of 3D Medical Data" [BMVC 2021].

OSS-Net: Memory Efficient High Resolution Semantic Segmentation of 3D Medical Data Christoph Reich, Tim Prangemeier, Özdemir Cetin & Heinz Koeppl | Pr

Memory Efficient Attention (O(sqrt(n)) for Jax and PyTorch

Memory Efficient Attention This is unofficial implementation of Self-attention Does Not Need O(n^2) Memory for Jax and PyTorch. Implementation is almo

Implementation of Memory-Efficient Neural Networks with Multi-Level Generation, ICCV 2021
Implementation of Memory-Efficient Neural Networks with Multi-Level Generation, ICCV 2021

Memory-Efficient Multi-Level In-Situ Generation (MLG) By Jiaqi Gu, Hanqing Zhu, Chenghao Feng, Mingjie Liu, Zixuan Jiang, Ray T. Chen and David Z. Pan

Memory-efficient optimum einsum using opt_einsum planning and PyTorch kernels.

opt-einsum-torch There have been many implementations of Einstein's summation. numpy's numpy.einsum is the least efficient one as it only runs in sing

Lowest memory consumption and second shortest runtime in NTIRE 2022 challenge on Efficient Super-Resolution

FMEN Lowest memory consumption and second shortest runtime in NTIRE 2022 on Efficient Super-Resolution. Our paper: Fast and Memory-Efficient Network T

XtremeDistil framework for distilling/compressing massive multilingual neural network models to tiny and efficient models for AI at scale

XtremeDistilTransformers for Distilling Massive Multilingual Neural Networks ACL 2020 Microsoft Research [Paper] [Video] Releasing [XtremeDistilTransf

Learning recognition/segmentation models without end-to-end training. 40%-60% less GPU memory footprint. Same training time. Better performance.
Learning recognition/segmentation models without end-to-end training. 40%-60% less GPU memory footprint. Same training time. Better performance.

InfoPro-Pytorch The Information Propagation algorithm for training deep networks with local supervision. (ICLR 2021) Revisiting Locally Supervised Lea

Efficient-GlobalPointer - Pytorch Efficient GlobalPointer
Efficient-GlobalPointer - Pytorch Efficient GlobalPointer

引言 感谢苏神带来的模型,原文地址:https://spaces.ac.cn/archives/8877 如何运行 对应模型EfficientGlobalPoi

Releases(v0.5.0)
Owner
Modar M. Alfadly
Deep learning researcher interested in understanding neural networks
Modar M. Alfadly
3D Generative Adversarial Network

Learning a Probabilistic Latent Space of Object Shapes via 3D Generative-Adversarial Modeling This repository contains pre-trained models and sampling

Chengkai Zhang 791 Dec 20, 2022
[CVPR2021] Invertible Image Signal Processing

Invertible Image Signal Processing This repository includes official codes for "Invertible Image Signal Processing (CVPR2021)". Figure: Our framework

Yazhou XING 281 Dec 31, 2022
PyTorch-Geometric Implementation of MarkovGNN: Graph Neural Networks on Markov Diffusion

MarkovGNN This is the official PyTorch-Geometric implementation of MarkovGNN paper under the title "MarkovGNN: Graph Neural Networks on Markov Diffusi

HipGraph: High-Performance Graph Analytics and Learning 6 Sep 23, 2022
Official PyTorch implementation of "ArtFlow: Unbiased Image Style Transfer via Reversible Neural Flows"

ArtFlow Official PyTorch implementation of the paper: ArtFlow: Unbiased Image Style Transfer via Reversible Neural Flows Jie An*, Siyu Huang*, Yibing

123 Dec 27, 2022
Code for CVPR2021 "Visualizing Adapted Knowledge in Domain Transfer". Visualization for domain adaptation. #explainable-ai

Visualizing Adapted Knowledge in Domain Transfer @inproceedings{hou2021visualizing, title={Visualizing Adapted Knowledge in Domain Transfer}, auth

Yunzhong Hou 80 Dec 25, 2022
Video Frame Interpolation with Transformer (CVPR2022)

VFIformer Official PyTorch implementation of our CVPR2022 paper Video Frame Interpolation with Transformer Dependencies python = 3.8 pytorch = 1.8.0

DV Lab 63 Dec 16, 2022
Anomaly Transformer: Time Series Anomaly Detection with Association Discrepancy" (ICLR 2022 Spotlight)

About Code release for Anomaly Transformer: Time Series Anomaly Detection with Association Discrepancy (ICLR 2022 Spotlight)

THUML @ Tsinghua University 221 Dec 31, 2022
Code and data for "TURL: Table Understanding through Representation Learning"

TURL This Repo contains code and data for "TURL: Table Understanding through Representation Learning". Environment and Setup Data Pretraining Finetuni

SunLab-OSU 63 Nov 23, 2022
Lightweight, Portable, Flexible Distributed/Mobile Deep Learning with Dynamic, Mutation-aware Dataflow Dep Scheduler; for Python, R, Julia, Scala, Go, Javascript and more

Apache MXNet (incubating) for Deep Learning Apache MXNet is a deep learning framework designed for both efficiency and flexibility. It allows you to m

The Apache Software Foundation 20.2k Jan 08, 2023
PyTorch implementation of U-TAE and PaPs for satellite image time series panoptic segmentation.

Panoptic Segmentation of Satellite Image Time Series with Convolutional Temporal Attention Networks (ICCV 2021) This repository is the official implem

71 Jan 04, 2023
✅ How Robust are Fact Checking Systems on Colloquial Claims?. In NAACL-HLT, 2021.

How Robust are Fact Checking Systems on Colloquial Claims? Official PyTorch implementation of our NAACL paper: Byeongchang Kim*, Hyunwoo Kim*, Seokhee

Byeongchang Kim 19 Mar 15, 2022
Hcaptcha-challenger - Gracefully face hCaptcha challenge with Yolov5(ONNX) embedded solution

hCaptcha Challenger 🚀 Gracefully face hCaptcha challenge with Yolov5(ONNX) embe

593 Jan 03, 2023
Learning cell communication from spatial graphs of cells

ncem Features Repository for the manuscript Fischer, D. S., Schaar, A. C. and Theis, F. Learning cell communication from spatial graphs of cells. 2021

Theis Lab 77 Dec 30, 2022
CV backbones including GhostNet, TinyNet and TNT, developed by Huawei Noah's Ark Lab.

CV Backbones including GhostNet, TinyNet, TNT (Transformer in Transformer) developed by Huawei Noah's Ark Lab. GhostNet Code TinyNet Code TNT Code Pyr

HUAWEI Noah's Ark Lab 3k Jan 08, 2023
Time-series-deep-learning - Developing Deep learning LSTM, BiLSTM models, and NeuralProphet for multi-step time-series forecasting of stock price.

Stock Price Prediction Using Deep Learning Univariate Time Series Predicting stock price using historical data of a company using Neural networks for

Abdultawwab Safarji 7 Nov 27, 2022
Official implementation of UTNet: A Hybrid Transformer Architecture for Medical Image Segmentation

UTNet (Accepted at MICCAI 2021) Official implementation of UTNet: A Hybrid Transformer Architecture for Medical Image Segmentation Introduction Transf

110 Jan 01, 2023
Deep generative modeling for time-stamped heterogeneous data, enabling high-fidelity models for a large variety of spatio-temporal domains.

Neural Spatio-Temporal Point Processes [arxiv] Ricky T. Q. Chen, Brandon Amos, Maximilian Nickel Abstract. We propose a new class of parameterizations

Facebook Research 75 Dec 19, 2022
PyTorch Implementation of "Non-Autoregressive Neural Machine Translation"

Non-Autoregressive Transformer Code release for Non-Autoregressive Neural Machine Translation by Jiatao Gu, James Bradbury, Caiming Xiong, Victor O.K.

Salesforce 261 Nov 12, 2022
face2comics by Sxela (Alex Spirin) - face2comics datasets

This is a paired face to comics dataset, which can be used to train pix2pix or similar networks.

Alex 164 Nov 13, 2022
Simple image captioning model - CLIP prefix captioning.

Simple image captioning model - CLIP prefix captioning.

688 Jan 04, 2023