Details about the wide minima density hypothesis and metrics to compute width of a minima

Overview

wide-minima-density-hypothesis

Details about the wide minima density hypothesis and metrics to compute width of a minima

This repo presents the wide minima density hypothesis as proposed in the following paper:

Key contributions:

  • Hypothesis about minima density
  • A SOTA LR schedule that exploits the hypothesis and beats general baseline schedules
  • Reducing wall clock training time and saving GPU compute hours with our LR schedule (Pretraining BERT-Large in 33% less training steps)
  • SOTA BLEU score on IWSLT'14 ( DE-EN )

Prerequisite:

  • CUDA, cudnn
  • Python 3.6+
  • PyTorch 1.4.0

Knee LR Schedule

Based on the density of wide vs narrow minima , we propose the Knee LR schedule that pushes generalization boundaries further by exploiting the nature of the loss landscape. The LR schedule is an explore-exploit based schedule, where the explore phase maintains a high lr for a significant time to access and land into a wide minimum with a good probability. The exploit phase is a simple linear decay scheme, which decays the lr to zero over the exploit phase. The only hyperparameter to tune is the explore epochs/steps. We have shown that 50% of the training budget allocated for explore is good enough for landing in a wider minimum and better generalization, thus removing the need for hyperparameter tuning.

  • Note that many experiments require warmup, which is done in the initial phase of training for a fixed number of steps and is usually required for Adam based optimizers/ large batch training. It is complementary with the Knee schedule and can be added to it.

To use the Knee Schedule, import the scheduler into your training file:

>>> from knee_lr_schedule import KneeLRScheduler
>>> scheduler = KneeLRScheduler(optimizer, peak_lr, warmup_steps, explore_steps, total_steps)

To use it during training :

>>> model.train()
>>> output = model(inputs)
>>> loss = criterion(output, targets)
>>> loss.backward()
>>> optimizer.step()
>>> scheduler.step()

Details about args:

  • optimizer: optimizer needed for training the model ( SGD/Adam )
  • peak_lr: the peak learning required for explore phase to escape narrow minimas
  • warmup_steps: steps required for warmup( usually needed for adam optimizers/ large batch training) Default value: 0
  • explore_steps: total steps for explore phase.
  • total_steps: total training budget steps for training the model

Measuring width of a minima

Keskar et.al 2016 (https://arxiv.org/abs/1609.04836) argue that wider minima generalize much better than sharper minima. The computation method in their work uses the compute expensive LBFGS-B second order method, which is hard to scale. We use a projected gradient ascent based method, which is first order in nature and very easy to implement/use. Here is a simple way you can compute the width of the minima your model finds during training:

>>> from minima_width_compute import ComputeKeskarSharpness
>>> cks = ComputeKeskarSharpness(model_final_ckpt, optimizer, criterion, trainloader, epsilon, lr, max_steps)
>>> width = cks.compute_sharpness()

Details about args:

  • model_final_ckpt: model loaded with the saved checkpoint after final training step
  • optimizer : optimizer to use for projected gradient ascent ( SGD, Adam )
  • criterion : criterion for computing loss (e.g. torch.nn.CrossEntropyLoss)
  • trainloader : iterator over the training dataset (torch.utils.data.DataLoader)
  • epsilon : epsilon value determines the local boundary around which minima witdh is computed (Default value : 1e-4)
  • lr : lr for the optimizer to perform projected gradient ascent ( Default: 0.001)
  • max_steps : max steps to compute the width (Default: 1000). Setting it too low could lead to the gradient ascent method not converging to an optimal point.

The above default values have been chosen after tuning and observing the loss values of projected gradient ascent on Cifar-10 with ResNet-18 and SGD-Momentum optimizer, as mentioned in our paper. The values may vary for experiments with other optimizers/datasets/models. Please tune them for optimal convergence.

  • Acknowledgements: We would like to thank Harshay Shah (https://github.com/harshays) for his helpful discussions for computing the width of the minima.

Citation

Please cite our paper in your publications if you use our work:

@article{iyer2020wideminima,
  title={Wide-minima Density Hypothesis and the Explore-Exploit Learning Rate Schedule},
  author={Iyer, Nikhil and Thejas, V and Kwatra, Nipun and Ramjee, Ramachandran and Sivathanu, Muthian},
  journal={arXiv preprint arXiv:2003.03977},
  year={2020}
}
  • Note: This work was done during an internship at Microsoft Research India
Owner
Nikhil Iyer
Studied at BITS-Pilani Hyderabad Campus. AI Research @ Jio-Haptik. Ex Microsoft Research India
Nikhil Iyer
An open source python library for automated feature engineering

"One of the holy grails of machine learning is to automate more and more of the feature engineering process." ― Pedro Domingos, A Few Useful Things to

alteryx 6.4k Jan 03, 2023
Multi-Horizon-Forecasting-for-Limit-Order-Books

Multi-Horizon-Forecasting-for-Limit-Order-Books This jupyter notebook is used to demonstrate our work, Multi-Horizon Forecasting for Limit Order Books

Zihao Zhang 116 Dec 23, 2022
An experiment to bait a generalized frontrunning MEV bot

Honeypot 🍯 A simple experiment that: Creates a honeypot contract Baits a generalized fronturnning bot with a unique transaction Analyze bot behaviour

0x1355 14 Nov 24, 2022
Retinal vessel segmentation based on GT-UNet

Retinal vessel segmentation based on GT-UNet Introduction This project is a retinal blood vessel segmentation code based on UNet-like Group Transforme

Kent0n 27 Dec 18, 2022
CSKG is a commonsense knowledge graph that combines seven popular sources into a consolidated representation

CSKG: The CommonSense Knowledge Graph CSKG is a commonsense knowledge graph that combines seven popular sources into a consolidated representation: AT

USC ISI I2 85 Dec 12, 2022
Theano is a Python library that allows you to define, optimize, and evaluate mathematical expressions involving multi-dimensional arrays efficiently. It can use GPUs and perform efficient symbolic differentiation.

============================================================================================================ `MILA will stop developing Theano https:

9.6k Dec 31, 2022
Flower classification model that classifies flowers in 10 classes made using transfer learning (~85% accuracy).

flower-classification-inceptionV3 Flower classification model that classifies flowers in 10 classes. Training and validation are done using a pre-anot

Ivan R. Mršulja 1 Dec 12, 2021
Scalable, Portable and Distributed Gradient Boosting (GBDT, GBRT or GBM) Library, for Python, R, Java, Scala, C++ and more. Runs on single machine, Hadoop, Spark, Dask, Flink and DataFlow

eXtreme Gradient Boosting Community | Documentation | Resources | Contributors | Release Notes XGBoost is an optimized distributed gradient boosting l

Distributed (Deep) Machine Learning Community 23.6k Dec 31, 2022
efficient neural audio synthesis in the waveform domain

neural waveshaping synthesis real-time neural audio synthesis in the waveform domain paper • website • colab • audio by Ben Hayes, Charalampos Saitis,

Ben Hayes 169 Dec 23, 2022
Space Ship Simulator using python

FlyOver Basic space-ship simulator using python How to run? Just double click run.py What modules do i need? All modules that i currently using is bui

0 Oct 09, 2022
Residual Dense Net De-Interlace Filter (RDNDIF)

Residual Dense Net De-Interlace Filter (RDNDIF) Work in progress deep de-interlacer filter. It is based on the architecture proposed by Bernasconi et

Louis 7 Feb 15, 2022
A project studying the influence of communication in multi-objective normal-form games

Communication in Multi-Objective Normal-Form Games This repo consists of five different types of agents that we have used in our study of communicatio

Willem Röpke 0 Dec 17, 2021
This project is the PyTorch implementation of our CVPR 2022 paper:

Requirements and Dependency Install PyTorch with CUDA (for GPU). (Experiments are validated on python 3.8.11 and pytorch 1.7.0) (For visualization if

Lei Huang 23 Nov 29, 2022
A collection of easy-to-use, ready-to-use, interesting deep neural network models

Interesting and reproducible research works should be conserved. This repository wraps a collection of deep neural network models into a simple and un

Aria Ghora Prabono 16 Jun 16, 2022
SE-MSCNN: A Lightweight Multi-scaled Fusion Network for Sleep Apnea Detection Using Single-Lead ECG Signals

SE-MSCNN: A Lightweight Multi-scaled Fusion Network for Sleep Apnea Detection Using Single-Lead ECG Signals Abstract Sleep apnea (SA) is a common slee

9 Dec 21, 2022
CyTran: Cycle-Consistent Transformers for Non-Contrast to Contrast CT Translation

CyTran: Cycle-Consistent Transformers for Non-Contrast to Contrast CT Translation We propose a novel approach to translate unpaired contrast computed

Nicolae Catalin Ristea 13 Jan 02, 2023
Official implementation of "Intrinsic Dimension, Persistent Homology and Generalization in Neural Networks", NeurIPS 2021.

PHDimGeneralization Official implementation of "Intrinsic Dimension, Persistent Homology and Generalization in Neural Networks", NeurIPS 2021. Overvie

Tolga Birdal 13 Nov 08, 2022
piSTAR Lab is a modular platform built to make AI experimentation accessible and fun. (pistar.ai)

piSTAR Lab WARNING: This is an early release. Overview piSTAR Lab is a modular deep reinforcement learning platform built to make AI experimentation a

piSTAR Lab 0 Aug 01, 2022
Model Zoo for AI Model Efficiency Toolkit

We provide a collection of popular neural network models and compare their floating point and quantized performance.

Qualcomm Innovation Center 137 Jan 03, 2023