Simple and efficient RevNet-Library with DeepSpeed support


RevLib


Features

  • Half the constant memory usage of other RevNet libraries, and faster than them
  • Less memory than gradient checkpointing (1 * output_size instead of n_layers * output_size)
  • Same speed as gradient (activation) checkpointing
  • Extensible
  • Trivial code (<100 lines)
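To make the second bullet concrete, here is a back-of-the-envelope sketch of how much activation memory has to be kept between layers. It assumes fp32 activations (4 bytes per element) and a hypothetical activation shape, and it only counts the per-layer outputs named in the bullet, ignoring weights, optimizer state and the temporary activations inside a single block while it is recomputed.

```python
from math import prod

BYTES_PER_FP32 = 4  # assumption: activations are stored in float32


def activation_bytes(shape):
    """Size in bytes of one fp32 activation tensor with the given shape."""
    return prod(shape) * BYTES_PER_FP32


def stored_bytes_checkpointing(shape, n_layers):
    # Gradient checkpointing keeps one output per layer: n_layers * output_size
    return n_layers * activation_bytes(shape)


def stored_bytes_revlib(shape):
    # Per the feature list, revlib keeps a constant 1 * output_size regardless of depth
    return activation_bytes(shape)


shape = (16, 128, 56, 56)  # hypothetical (batch, channels, height, width)
for n_layers in (6, 12, 48):
    ckpt = stored_bytes_checkpointing(shape, n_layers) / 2**20
    rev = stored_bytes_revlib(shape) / 2**20
    print(f"{n_layers:>3} layers: checkpointing {ckpt:8.1f} MiB vs revlib {rev:6.1f} MiB")
```

With this shape a single activation is 24.5 MiB, so the checkpointing figure grows linearly with depth while the revlib figure stays flat.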

Getting started

Installation

```shell
python3 -m pip install revlib
```

Examples

iRevNet

iRevNet is not merely partially reversible but a fully invertible model. Its original source code looks complex at first glance, and it doesn't exploit the memory savings it could, because RevNet-style memory savings require custom autograd functions that are hard to maintain. Using revlib, an iRevNet can be implemented like this:

```python
import torch
from torch import nn
import revlib

channels = 64
channel_multiplier = 4
depth = 3
classes = 1000


# 3x3 convolution that preserves the spatial resolution
def conv(in_channels, out_channels):
    return nn.Conv2d(in_channels, out_channels, (3, 3), padding=1)


def block_conv(in_channels, out_channels):
    return nn.Sequential(conv(in_channels, out_channels),
                         nn.Dropout(0.2),
                         nn.BatchNorm2d(out_channels),
                         nn.ReLU())


# The basic function that's executed reversibly multiple times (like f() in ResNet).
# It maps `channels` to `channels`, because each of the two reversible streams has `channels` channels.
def block():
    return nn.Sequential(block_conv(channels, channels * channel_multiplier),
                         block_conv(channels * channel_multiplier, channels),
                         conv(channels, channels))


# Create a reversible model. f() is invoked depth times with different weights.
rev_model = revlib.ReversibleSequential(*[block() for _ in range(depth)])

# Wrap the reversible model with non-reversible layers.
# The stem outputs 2 * channels, which ReversibleSequential splits into two streams of `channels` each.
model = nn.Sequential(conv(3, 2 * channels), rev_model, conv(2 * channels, classes))

# Use it like you would a regular PyTorch model
inp = torch.randn((1, 3, 224, 224))
out = model(inp)
out.mean().backward()
assert out.size() == (1, 1000, 224, 224)
```

MomentumNet

MomentumNet comes from another recent paper that made significant advances in memory-efficient networks. Its authors propose using a momentum stream instead of a second model output. Implementing that with revlib requires writing a custom coupling operation (the functional analogue of MemCNN's couplings) that merges the input and output streams, together with its inverse.

```python
import torch
from torch import nn
import revlib

channels = 64
depth = 16
momentum_ema_beta = 0.99


# Compute y2 from x2 and f(x1) by merging x2 and f(x1) in the forward pass.
def momentum_coupling_forward(other_stream: torch.Tensor, fn_out: torch.Tensor) -> torch.Tensor:
    return other_stream * momentum_ema_beta + fn_out * (1 - momentum_ema_beta)


# Calculate x2 from y2 and f(x1) by manually computing the inverse of momentum_coupling_forward.
def momentum_coupling_inverse(output: torch.Tensor, fn_out: torch.Tensor) -> torch.Tensor:
    return (output - fn_out * (1 - momentum_ema_beta)) / momentum_ema_beta


# Pass in coupling functions, which will be used instead of x2 + f(x1) and y2 - f(x1).
# Layers alternate between a convolution (merged with momentum coupling)
# and an identity (merged with the default additive coupling).
rev_model = revlib.ReversibleSequential(*[layer for _ in range(depth)
                                          for layer in [nn.Conv2d(channels, channels, (3, 3), padding=1),
                                                        nn.Identity()]],
                                        coupling_forward=[momentum_coupling_forward, revlib.additive_coupling_forward],
                                        coupling_inverse=[momentum_coupling_inverse, revlib.additive_coupling_inverse])

inp = torch.randn((16, channels * 2, 224, 224))
out = rev_model(inp)
assert out.size() == (16, channels * 2, 224, 224)
```
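A custom coupling is only useful if its inverse undoes it exactly, so it is worth checking that pair before training. The sketch below does this without PyTorch: it reuses the two coupling formulas from the example on plain numbers, with beta written as a fractions.Fraction so the round trip is exact (with floats, dividing by beta introduces rounding error). The function f is a hypothetical stand-in for the convolution, and momentum_step chains one momentum coupling with one additive coupling, mirroring the alternating convolution/identity layers above, i.e. v = beta * v + (1 - beta) * f(x), then x = x + v.

```python
from fractions import Fraction

beta = Fraction(99, 100)  # same value as momentum_ema_beta, as an exact fraction


def momentum_coupling_forward(other_stream, fn_out):
    return other_stream * beta + fn_out * (1 - beta)


def momentum_coupling_inverse(output, fn_out):
    return (output - fn_out * (1 - beta)) / beta


def f(x):
    # Hypothetical stand-in for the convolution; any deterministic function works.
    return 3 * x * x - 2


def momentum_step(x, v):
    """v <- beta * v + (1 - beta) * f(x), then x <- x + v."""
    v = momentum_coupling_forward(v, f(x))
    x = x + v  # additive coupling with an identity function
    return x, v


def momentum_step_inverse(x, v):
    """Undo momentum_step using only its outputs."""
    x = x - v  # invert the additive coupling first
    v = momentum_coupling_inverse(v, f(x))
    return x, v


x0, v0 = Fraction(1, 3), Fraction(-5, 7)
x, v = x0, v0
for _ in range(5):
    x, v = momentum_step(x, v)
for _ in range(5):
    x, v = momentum_step_inverse(x, v)
assert (x, v) == (x0, v0)
```

Running the steps forward and then inverting them recovers the starting values exactly, which is the property that lets a reversible model recompute inputs during the backward pass instead of storing them.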

Reformer

Reformer uses RevNet with chunking and LSH attention to train a transformer efficiently. With revlib, standard implementations such as lucidrains' Reformer can be improved to use less memory. The example below still uses the basic building blocks from lucidrains' code (the reformer_pytorch package) so that the resulting model stays comparable.

```python
import torch
from torch import nn
from reformer_pytorch.reformer_pytorch import LSHSelfAttention, Chunk, FeedForward, AbsolutePositionalEmbedding
import revlib


class Reformer(torch.nn.Module):
    def __init__(self, sequence_length: int, features: int, depth: int, heads: int, bucket_size: int = 64,
                 lsh_hash_count: int = 8, ff_chunks: int = 16, input_classes: int = 256, output_classes: int = 256):
        super().__init__()
        self.token_embd = nn.Embedding(input_classes, features * 2)
        self.pos_embd = AbsolutePositionalEmbedding(features * 2, sequence_length)

        # split_dim=-1 splits the features * 2 embedding into two reversible streams of `features` each
        self.core = revlib.ReversibleSequential(*[nn.Sequential(nn.LayerNorm(features), layer) for _ in range(depth)
                                                 for layer in
                                                 [LSHSelfAttention(features, heads, bucket_size, lsh_hash_count),
                                                  Chunk(ff_chunks, FeedForward(features, activation=nn.GELU),
                                                        along_dim=-2)]],
                                                split_dim=-1)
        self.out_norm = nn.LayerNorm(features * 2)
        self.out_linear = nn.Linear(features * 2, output_classes)

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        return self.out_linear(self.out_norm(self.core(self.token_embd(inp) + self.pos_embd(inp))))


sequence = 1024
classes = 16
model = Reformer(sequence, 256, 6, 8, output_classes=classes)
out = model(torch.ones((16, sequence), dtype=torch.long))
assert out.size() == (16, sequence, classes)
```
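In the example, Chunk(ff_chunks, FeedForward(...), along_dim=-2) splits the sequence dimension into ff_chunks pieces, runs the feed-forward layer on each piece in turn and concatenates the results. Because a transformer feed-forward layer acts on every position independently, chunking reduces how much is computed at once without changing the output. The framework-free sketch below illustrates that property with plain lists; position_wise_ff, feed_forward and chunked_apply are hypothetical helpers, not part of reformer_pytorch or revlib.

```python
def position_wise_ff(features):
    # Hypothetical stand-in for FeedForward on ONE position: 4x expansion, ReLU, then a
    # projection that mixes all features of this position (it never sees other positions).
    hidden = [max(0, w * x + b) for x in features for w, b in ((1, 0), (-1, 1), (2, -3), (-2, 5))]
    total = sum(hidden)
    return [total - h for h in hidden[::4]]


def feed_forward(sequence):
    # Apply the position-wise layer independently to every position of a (sub-)sequence.
    return [position_wise_ff(position) for position in sequence]


def chunked_apply(fn, sequence, chunks):
    # Split the sequence axis into contiguous pieces of ceil(len / chunks) positions
    # (the sizes torch.chunk produces), run fn on each piece and concatenate the results.
    size = -(-len(sequence) // chunks)
    out = []
    for start in range(0, len(sequence), size):
        out.extend(fn(sequence[start:start + size]))
    return out


# 1024 positions with 8 integer features each, so the comparison below is exact.
sequence = [[(7 * t + d) % 11 - 5 for d in range(8)] for t in range(1024)]
full = feed_forward(sequence)
assert chunked_apply(feed_forward, sequence, 16) == full
```

The same equality would not hold for the attention layer, which mixes positions; that is why only the feed-forward part is wrapped in Chunk.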

Explanation

Most other RevNet libraries, such as MemCNN and Revtorch, calculate both f() and g() in one go, creating one large computation. RevLib, on the other hand, brings Mesh TensorFlow's "reversible half residual and swap" to PyTorch. reversible_half_residual_and_swap computes only one of f() and g() and then swaps the inputs and gradients. This way, the library only has to store one output, as it can recover the other output during the backward pass.
Following Mesh TensorFlow's example, revlib also keeps x1 and x2 as separate tensors instead of concatenating and splitting them at every step, which reduces the cost of memory-bound operations.
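The following is a minimal, framework-free sketch of the "half residual and swap" idea, not revlib's actual code (which operates on tensors inside custom autograd functions). Each step updates only one of the two streams with a single residual function and then swaps the streams, so every step's inputs can be recomputed from its outputs instead of being stored. The residual functions are hypothetical integer functions, chosen so that the reconstruction is exact.

```python
def half_residual_and_swap(a, b, fn):
    # Update only stream b with fn(a), then swap so the freshly updated stream feeds the next fn.
    return b + fn(a), a


def inverse_half_residual_and_swap(c, d, fn):
    # The outputs were (b + fn(a), a): a is available directly, b is recovered by subtracting fn(a).
    a = d
    b = c - fn(a)
    return a, b


# Hypothetical residual functions; consecutive pairs play the roles of f() and g().
functions = [lambda x: 3 * x + 1, lambda x: x * x - 7, lambda x: 5 - 2 * x, lambda x: x // 2 + 4]


def forward(x1, x2):
    for fn in functions:
        x1, x2 = half_residual_and_swap(x1, x2, fn)
    return x1, x2


def reconstruct_inputs(y1, y2):
    # Walk the steps backwards; only the final outputs of the forward pass are needed.
    for fn in reversed(functions):
        y1, y2 = inverse_half_residual_and_swap(y1, y2, fn)
    return y1, y2


assert reconstruct_inputs(*forward(4, -9)) == (4, -9)
```

Two consecutive steps reproduce a classic RevNet block, y1 = x2 + f(x1) followed by y2 = x1 + g(y1), but each step only evaluates one function.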

RevNet's memory consumption doesn't scale with its depth, so it's significantly more memory-efficient for deep models. One problem in most implementations was that two tensors needed to be stored in the output, quadrupling the required memory. This high memory consumption made RevNet nearly useless for small networks such as BERT, which has only 12 layers (BERT-base).
RevLib works around this problem by storing only one output and two inputs for each forward pass, giving even a model as small as BERT a >2x improvement!

Even setting aside its dual-path structure, a RevNet used to be much slower than gradient checkpointing. RevLib, however, uses minimal coupling functions and adds no overhead between the modules of a ReversibleSequential, allowing it to train as fast as a comparable model with gradient checkpointing.

Owner: Lucas Nestler (German AI researcher)