InvTorch: memory-efficient models with invertible functions

Related tags

Deep Learninginvtorch
Overview

InvTorch: Memory-Efficient Invertible Functions

This module extends the functionality of torch.utils.checkpoint.checkpoint to work with invertible functions. So, not only the intermediate activations will be released from memory. The input tensors get deallocated and recomputed later using the inverse function only in the backward pass. This is useful in extreme situations where more compute is traded with memory. However, there are few caveats to consider which are detailed here.

Installation

InvTorch has minimal dependencies. It only requires PyTorch version 1.10.0 or later.

conda install pytorch==1.10.0 torchvision torchaudio cudatoolkit=11.3 -c pytorch
pip install invtorch

Basic Usage

The main module that we are interested in is InvertibleModule which inherits from torch.nn.Module. Subclass it to implement your own invertible code.

import torch
from torch import nn
from invtorch import InvertibleModule


class InvertibleLinear(InvertibleModule):
    def __init__(self, in_features, out_features):
        super().__init__(invertible=True, checkpoint=True)
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.bias = nn.Parameter(torch.randn(out_features))

    def function(self, inputs):
        outputs = inputs @ self.weight.T + self.bias
        requires_grad = self.do_require_grad(inputs, self.weight, self.bias)
        return outputs.requires_grad_(requires_grad)

    def inverse(self, outputs):
        return (outputs - self.bias) @ self.weight.T.pinverse()

Structure

You can immediately notice few differences to the regular PyTorch module here. There is no longer a need to define forward(). Instead, it is replaced with function(*inputs). Additionally, it is necessary to define its inverse function as inverse(*outputs). Both methods can only take one or more positional arguments and return a torch.Tensor or a tuple of outputs which can have anything including tensors.

Requires Gradient

function() must manually call .requires_grad_(True/False) on all output tensors. The forward pass is run in no_grad mode and there is no way to detect which output need gradients without tracing. It is possible to infer this from requires_grad values of the inputs and self.parameters(). The above code uses do_require_grad() which returns True if any input did require gradient.

Example

Now, this model is ready to be instantiated and used directly.

x = torch.randn(10, 3)
model = InvertibleLinear(3, 5)
print('Is invertible:', model.check_inverse(x))

y = model(x)
print('Output requires_grad:', y.requires_grad)
print('Input was freed:', x.storage().size() == 0)

y.backward(torch.randn_like(y))
print('Input was restored:', x.storage().size() != 0)

Checkpoint and Invertible Modes

InvertibleModule has two flags which control the mode of operation; checkpoint and invertible. If checkpoint was set to False, or when working in no_grad mode, or no input or parameter has requires_grad set to True, it acts exactly as a normal PyTorch module. Otherwise, the model is either invertible or an ordinary checkpoint depending on whether invertible is set to True or False, respectively. Those, flags can be changed at any time during operation without any repercussions.

Limitations

Under the hood, InvertibleModule uses invertible_checkpoint(); a low-level implementation which allows it to function. There are few considerations to keep in mind when working with invertible checkpoints and non-materialized tensors. Please, refer to the documentation in the code for more details.

Overriding forward()

Although forward() is now doing important things to ensure the validity of the results when calling invertible_checkpoint(), it can still be overridden. The main reason of doing so is to provide a more user-friendly interface; function signature and output format. For example, function() could return extra outputs that are not needed in the module outputs but are essential for correctly computing the inverse(). In such case, define forward() to wrap outputs = super().forward(*inputs) more cleanly.

TODOs

Here are few feature ideas that could be implemented to enrich the utility of this package:

  • Add more basic operations and modules
  • Add coupling and interleave -based invertible operations
  • Add more checks to help the user in debugging more features
  • Allow picking some inputs to not be freed in invertible mode
  • Context-manager to temporarily change the mode of operation
  • Implement dynamic discovery for outputs that requires_grad
  • Develop an automatic mode optimization for a network for various objectives
You might also like...
A memory-efficient implementation of DenseNets

efficient_densenet_pytorch A PyTorch =1.0 implementation of DenseNets, optimized to save GPU memory. Recent updates Now works on PyTorch 1.0! It uses

Official and maintained implementation of the paper
Official and maintained implementation of the paper "OSS-Net: Memory Efficient High Resolution Semantic Segmentation of 3D Medical Data" [BMVC 2021].

OSS-Net: Memory Efficient High Resolution Semantic Segmentation of 3D Medical Data Christoph Reich, Tim Prangemeier, Özdemir Cetin & Heinz Koeppl | Pr

Memory Efficient Attention (O(sqrt(n)) for Jax and PyTorch

Memory Efficient Attention This is unofficial implementation of Self-attention Does Not Need O(n^2) Memory for Jax and PyTorch. Implementation is almo

Implementation of Memory-Efficient Neural Networks with Multi-Level Generation, ICCV 2021
Implementation of Memory-Efficient Neural Networks with Multi-Level Generation, ICCV 2021

Memory-Efficient Multi-Level In-Situ Generation (MLG) By Jiaqi Gu, Hanqing Zhu, Chenghao Feng, Mingjie Liu, Zixuan Jiang, Ray T. Chen and David Z. Pan

Memory-efficient optimum einsum using opt_einsum planning and PyTorch kernels.

opt-einsum-torch There have been many implementations of Einstein's summation. numpy's numpy.einsum is the least efficient one as it only runs in sing

Lowest memory consumption and second shortest runtime in NTIRE 2022 challenge on Efficient Super-Resolution

FMEN Lowest memory consumption and second shortest runtime in NTIRE 2022 on Efficient Super-Resolution. Our paper: Fast and Memory-Efficient Network T

XtremeDistil framework for distilling/compressing massive multilingual neural network models to tiny and efficient models for AI at scale

XtremeDistilTransformers for Distilling Massive Multilingual Neural Networks ACL 2020 Microsoft Research [Paper] [Video] Releasing [XtremeDistilTransf

Learning recognition/segmentation models without end-to-end training. 40%-60% less GPU memory footprint. Same training time. Better performance.
Learning recognition/segmentation models without end-to-end training. 40%-60% less GPU memory footprint. Same training time. Better performance.

InfoPro-Pytorch The Information Propagation algorithm for training deep networks with local supervision. (ICLR 2021) Revisiting Locally Supervised Lea

Efficient-GlobalPointer - Pytorch Efficient GlobalPointer
Efficient-GlobalPointer - Pytorch Efficient GlobalPointer

引言 感谢苏神带来的模型,原文地址:https://spaces.ac.cn/archives/8877 如何运行 对应模型EfficientGlobalPoi

Releases(v0.5.0)
Owner
Modar M. Alfadly
Deep learning researcher interested in understanding neural networks
Modar M. Alfadly
Learning an Adaptive Meta Model-Generator for Incrementally Updating Recommender Systems

Learning an Adaptive Meta Model-Generator for Incrementally Updating Recommender Systems This is our experimental code for RecSys 2021 paper "Learning

11 Jul 28, 2022
Revitalizing CNN Attention via Transformers in Self-Supervised Visual Representation Learning

Revitalizing CNN Attention via Transformers in Self-Supervised Visual Representation Learning

ChongjianGE 89 Dec 02, 2022
Official PyTorch implementation of the paper Image-Based CLIP-Guided Essence Transfer.

TargetCLIP- official pytorch implementation of the paper Image-Based CLIP-Guided Essence Transfer This repository finds a global direction in StyleGAN

Hila Chefer 221 Dec 13, 2022
Moving Object Segmentation in 3D LiDAR Data: A Learning-based Approach Exploiting Sequential Data

LiDAR-MOS: Moving Object Segmentation in 3D LiDAR Data This repo contains the code for our paper: Moving Object Segmentation in 3D LiDAR Data: A Learn

Photogrammetry & Robotics Bonn 394 Dec 29, 2022
python debugger and anti-vm that checks if you're in a virtual machine or if someones trying to debug your file

Anti-Debug was made by Love ❌ code ✅ 🎉 ・What it checks for ・ Kills tools that can be used to debug your file ・ Exits if ran in vm (supports different

Rdimo 31 Aug 09, 2022
Model-based Reinforcement Learning Improves Autonomous Racing Performance

Racing Dreamer: Model-based versus Model-free Deep Reinforcement Learning for Autonomous Racing Cars In this work, we propose to learn a racing contro

Cyber Physical Systems - TU Wien 38 Dec 06, 2022
The implemetation of Dynamic Nerual Garments proposed in Siggraph Asia 2021

DynamicNeuralGarments Introduction This repository contains the implemetation of Dynamic Nerual Garments proposed in Siggraph Asia 2021. ./GarmentMoti

42 Dec 27, 2022
This project helps to colorize grayscale images using multiple exemplars.

Multiple Exemplar-based Deep Colorization (Pytorch Implementation) Pretrained Model [Jitendra Chautharia](IIT Jodhpur)1,3, Prerequisites Python 3.6+ N

jitendra chautharia 3 Aug 05, 2022
Texture mapping with variational auto-encoders

vae-textures This is an experiment with using variational autoencoders (VAEs) to perform mesh parameterization. This was also my first project using J

Alex Nichol 41 May 24, 2022
[MedIA2021]MIDeepSeg: Minimally Interactive Segmentation of Unseen Objects from Medical Images Using Deep Learning

MIDeepSeg: Minimally Interactive Segmentation of Unseen Objects from Medical Images Using Deep Learning [MedIA or Arxiv] and [Demo] This repository pr

Healthcare Intelligence Laboratory 92 Dec 08, 2022
Official code of our work, AVATAR: A Parallel Corpus for Java-Python Program Translation.

AVATAR Official code of our work, AVATAR: A Parallel Corpus for Java-Python Program Translation. AVATAR stands for jAVA-pyThon progrAm tRanslation. AV

Wasi Ahmad 26 Dec 03, 2022
Implementation of "A Deep Learning Loss Function based on Auditory Power Compression for Speech Enhancement" by pytorch

This repository is used to suspend the results of our paper "A Deep Learning Loss Function based on Auditory Power Compression for Speech Enhancement"

ScorpioMiku 19 Sep 30, 2022
PyTorch code for MART: Memory-Augmented Recurrent Transformer for Coherent Video Paragraph Captioning

MART: Memory-Augmented Recurrent Transformer for Coherent Video Paragraph Captioning PyTorch code for our ACL 2020 paper "MART: Memory-Augmented Recur

Jie Lei 雷杰 151 Jan 06, 2023
This project aims to segment 4 common retinal lesions from Fundus Images.

This project aims to segment 4 common retinal lesions from Fundus Images.

Husam Nujaim 1 Oct 10, 2021
Deep Q Learning with OpenAI Gym and Pokemon Showdown

pokemon-deep-learning An openAI gym project for pokemon involving deep q learning. Made by myself, Sam Little, and Layton Webber. This code captures g

2 Dec 22, 2021
HyperSeg: Patch-wise Hypernetwork for Real-time Semantic Segmentation Official PyTorch Implementation

: We present a novel, real-time, semantic segmentation network in which the encoder both encodes and generates the parameters (weights) of the decoder. Furthermore, to allow maximal adaptivity, the w

Yuval Nirkin 182 Dec 14, 2022
Attention over nodes in Graph Neural Networks using PyTorch (NeurIPS 2019)

Intro This repository contains code to generate data and reproduce experiments from our NeurIPS 2019 paper: Boris Knyazev, Graham W. Taylor, Mohamed R

Boris Knyazev 242 Jan 06, 2023
Auto-Lama combines object detection and image inpainting to automate object removals

Auto-Lama Auto-Lama combines object detection and image inpainting to automate object removals. It is build on top of DE:TR from Facebook Research and

44 Dec 09, 2022
Official code release for ICCV 2021 paper SNARF: Differentiable Forward Skinning for Animating Non-rigid Neural Implicit Shapes.

Official code release for ICCV 2021 paper SNARF: Differentiable Forward Skinning for Animating Non-rigid Neural Implicit Shapes.

235 Dec 26, 2022
This repository contains the implementation of the HealthGen model, a generative model to synthesize realistic EHR time series data with missingness

HealthGen: Conditional EHR Time Series Generation This repository contains the implementation of the HealthGen model, a generative model to synthesize

0 Jan 20, 2022