🛠️ Tools for Transformers compression using Lightning ⚡

Overview

Hits

Bert-squeeze

Bert-squeeze is a repository aiming to provide code to reduce the size of Transformer-based models or decrease their latency at inference time.

It gathers a non-exhaustive list of techniques such as distillation, pruning, quantization, early-exiting. The repo is written using PyTorch Lightning and Transformers.

About the project

As a heavy user of transformer-based models (which are truly amazing from my point of view) I always struggled to put those heavy models in production while having a decent inference speed. There are of course a bunch of existing libraries to optimize and compress transformer-based models (ONNX , distiller, compressors , KD_Lib, ... ).
I started this project because of the need to reduce the latency of models integrating transformers as subcomponents. For this reason, this project aims at providing implementations to train various transformer-based models (and others) using PyTorch Lightning but also to distill, prune, and quantize models.
I chose to write this repo with Lightning because of its growing trend, its flexibility, and the very few repositories using it. It currently only handles sequence classification models, but support for other tasks and custom architectures is planned.

Installation

First download the repository:

git clone https://github.com/JulesBelveze/bert-squeeze.git

and then install dependencies using poetry:

poetry install

You are all set!

Quickstarts

You can find a bunch of already prepared configurations under the examples folder. Just choose the one you need and run the following:

python3 -m bert-squeeze.main -cp=examples -cn=wanted_config

Disclaimer: I have not extensively tested all procedures and thus do not guarantee the performance of every implemented method.

Concepts

Transformers

If you never heard of it then I can only recommend you to read this amazing blog post and if you want to dig deeper there is this awesome lecture was given by Stanford available here.

Distillation

The idea of distillation is to train a small network to mimic a big network by trying to replicate its outputs. The repository provides the ability to transfer knowledge from any model to any other (if you need a model that is not within the models folder just write your own).

The repository also provides the possibility to perform soft-distillation or hard-distillation on an unlabeled dataset. In the soft case, we use the probabilities of the teacher as a target. In the hard one, we assume that the teacher's predictions are the actual label.

You can find these implementations under the distillation/ folder.

Quantization

Neural network quantization is the process of reducing the weights precision in the neural network. The repo has two callbacks one for dynamic quantization and one for quantization-aware training (using the Lightning callback) .

You can find those implementations under the utils/callbacks/ folder.

Pruning

Pruning neural networks consist of removing weights from trained models to compress them. This repo features various pruning implementations and methods such as head-pruning, layer dropping, and weights dropping.

You can find those implementations under the utils/callbacks/ folder.

Contributions and questions

If you are missing a feature that could be relevant to this repo, or a bug that you noticed feel free to open a PR or open an issue. As you can see in the roadmap there are a bunch more features to come 😃

Also, if you have any questions or suggestions feel free to ask!

References

  1. Alammar, J (2018). The Illustrated Transformer [Blog post]. Retrieved from https://jalammar.github.io/illustrated-transformer/
  2. stanfordonline (2021) Stanford CS224N NLP with Deep Learning | Winter 2021 | Lecture 9 - Self- Attention and Transformers. [online video] Available at: https://www.youtube.com/watch?v=ptuGllU5SQQ
  3. Thomas Wolf and Lysandre Debut and Victor Sanh and Julien Chaumond and Clement Delangue and Anthony Moi and Pierric Cistac and Tim Rault and Rémi Louf and Morgan Funtowicz and Jamie Brew (2019). HuggingFace's Transformers: State-of-the-art Natural Language Processing
  4. Hassan Sajjad and Fahim Dalvi and Nadir Durrani and Preslav Nakov (2020). Poor Man's BERT Smaller and Faster Transformer Models
  5. Angela Fan and Edouard Grave and Armand Joulin (2019). Reducing Transformer Depth on Demand with Structured Dropout
  6. Paul Michel and Omer Levy and Graham Neubig (2019). Are Sixteen Heads Really Better than One?
  7. Fangxiaoyu Feng and Yinfei Yang and Daniel Cer and Naveen Arivazhagan and Wei Wang (2020). Language-agnostic BERT Sentence Embedding
Owner
Jules Belveze
AI craftsman | NLP | MLOps
Jules Belveze
A Python library created to assist programmers with complex mathematical functions

libmaths libmaths was created not only as a learning experience for me, but as a way to make mathematical models in seconds for Python users using mat

Simple 73 Oct 02, 2022
Code for "Neural 3D Scene Reconstruction with the Manhattan-world Assumption" CVPR 2022 Oral

News 05/10/2022 To make the comparison on ScanNet easier, we provide all quantitative and qualitative results of baselines here, including COLMAP, COL

ZJU3DV 365 Dec 30, 2022
[WACV 2022] Contextual Gradient Scaling for Few-Shot Learning

CxGrad - Official PyTorch Implementation Contextual Gradient Scaling for Few-Shot Learning Sanghyuk Lee, Seunghyun Lee, and Byung Cheol Song In WACV 2

Sanghyuk Lee 4 Dec 05, 2022
A general, feasible, and extensible framework for classification tasks.

Pytorch Classification A general, feasible and extensible framework for 2D image classification. Features Easy to configure (model, hyperparameters) T

Eugene 26 Nov 22, 2022
Official implementation for Scale-Aware Neural Architecture Search for Multivariate Time Series Forecasting

1 SNAS4MTF This repo is the official implementation for Scale-Aware Neural Architecture Search for Multivariate Time Series Forecasting. 1.1 The frame

SZJ 5 Sep 21, 2022
Python implementation of "Elliptic Fourier Features of a Closed Contour"

PyEFD An Python/NumPy implementation of a method for approximating a contour with a Fourier series, as described in [1]. Installation pip install pyef

Henrik Blidh 71 Dec 09, 2022
A script written in Python that returns a consensus string and profile matrix of a given DNA string(s) in FASTA format.

A script written in Python that returns a consensus string and profile matrix of a given DNA string(s) in FASTA format.

Zain 1 Feb 01, 2022
How to Become More Salient? Surfacing Representation Biases of the Saliency Prediction Model

How to Become More Salient? Surfacing Representation Biases of the Saliency Prediction Model

Bogdan Kulynych 49 Nov 05, 2022
This reposityory contains the PyTorch implementation of our paper "Generative Dynamic Patch Attack".

Generative Dynamic Patch Attack This reposityory contains the PyTorch implementation of our paper "Generative Dynamic Patch Attack". Requirements PyTo

Xiang Li 8 Nov 17, 2022
🗣️ Microsoft Edge TTS for Home Assistant, no need for app_key

Microsoft Edge TTS for Home Assistant This component is based on the TTS service of Microsoft Edge browser, no need to apply for app_key. Install Down

152 Dec 31, 2022
Official implementation of "Towards Good Practices for Efficiently Annotating Large-Scale Image Classification Datasets" (CVPR2021)

Towards Good Practices for Efficiently Annotating Large-Scale Image Classification Datasets This is the official implementation of "Towards Good Pract

Sanja Fidler's Lab 52 Nov 22, 2022
Specificity-preserving RGB-D Saliency Detection

Specificity-preserving RGB-D Saliency Detection Authors: Tao Zhou, Huazhu Fu, Geng Chen, Yi Zhou, Deng-Ping Fan, and Ling Shao. 1. Preface This reposi

Tao Zhou 35 Jan 08, 2023
Code for sound field predictions in domains with impedance boundaries. Used for generating results from the paper

Code for sound field predictions in domains with impedance boundaries. Used for generating results from the paper

DTU Acoustic Technology Group 11 Dec 17, 2022
Build fully-functioning computer vision models with PyTorch

Detecto is a Python package that allows you to build fully-functioning computer vision and object detection models with just 5 lines of code. Inferenc

Alan Bi 576 Dec 29, 2022
Introducing neural networks to predict stock prices

IntroNeuralNetworks in Python: A Template Project IntroNeuralNetworks is a project that introduces neural networks and illustrates an example of how o

Vivek Palaniappan 637 Jan 04, 2023
Pytorch implementation of our paper accepted by NeurIPS 2021 -- Revisiting Discriminator in GAN Compression: A Generator-discriminator Cooperative Compression Scheme

Revisiting Discriminator in GAN Compression: A Generator-discriminator Cooperative Compression Scheme (NeurIPS2021) (Link) Overview Prerequisites Linu

Shaojie Li 34 Mar 31, 2022
X-VLM: Multi-Grained Vision Language Pre-Training

X-VLM: learning multi-grained vision language alignments Multi-Grained Vision Language Pre-Training: Aligning Texts with Visual Concepts. Yan Zeng, Xi

Yan Zeng 286 Dec 23, 2022
Public repo for the ICCV2021-CVAMD paper "Is it Time to Replace CNNs with Transformers for Medical Images?"

Is it Time to Replace CNNs with Transformers for Medical Images? Accepted at ICCV-2021: Workshop on Computer Vision for Automated Medical Diagnosis (C

Christos Matsoukas 80 Dec 27, 2022
BoxInst: High-Performance Instance Segmentation with Box Annotations

Introduction This repository is the code that needs to be submitted for OpenMMLab Algorithm Ecological Challenge, the paper is BoxInst: High-Performan

88 Dec 21, 2022
3rd Place Solution for ICCV 2021 Workshop SSLAD Track 3A - Continual Learning Classification Challenge

Online Continual Learning via Multiple Deep Metric Learning and Uncertainty-guided Episodic Memory Replay 3rd Place Solution for ICCV 2021 Workshop SS

Rifki Kurniawan 6 Nov 10, 2022