2.86% and 15.85% on CIFAR-10 and CIFAR-100

Overview

Shake-Shake regularization

This repository contains the code for the paper Shake-Shake regularization. This arXiv paper extends "Shake-Shake regularization of 3-branch residual networks", which was accepted as a workshop contribution at ICLR 2017.

The code is based on fb.resnet.torch.

Table of Contents

  1. Introduction
  2. Results
  3. Usage
  4. Contact

Introduction

The method introduced in this paper aims to help deep learning practitioners facing an overfitting problem. The idea is to replace, in a multi-branch network, the standard summation of parallel branches with a stochastic affine combination. Applied to 3-branch residual networks, shake-shake regularization improves on the best published single-shot results on CIFAR-10 and CIFAR-100 by reaching test errors of 2.86% and 15.85%.

shake-shake

Figure 1: Left: Forward training pass. Center: Backward training pass. Right: At test time.
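
The three passes in Figure 1 can be sketched in a few lines of plain Python. This is only an illustration of the stochastic affine combination described above, not the Lua code from this repository: the function names are hypothetical, tensors are flattened to lists of floats, and the coefficients are assumed to be drawn uniformly from [0, 1).

```python
import random


def combine(branch1, branch2, coeff):
    """Return coeff * branch1 + (1 - coeff) * branch2, element by element."""
    return [coeff * a + (1.0 - coeff) * b for a, b in zip(branch1, branch2)]


def residual_forward(x, branch1, branch2, training, rng=random):
    """Skip connection plus the combined outputs of the two residual branches.

    Training: draw a fresh alpha (assumed uniform in [0, 1)).
    Test time: use the fixed coefficient 0.5.
    """
    alpha = rng.random() if training else 0.5
    mixed = combine(branch1, branch2, alpha)
    return [xi + mi for xi, mi in zip(x, mixed)]


def branch_gradients(grad_output, rng=random):
    """Backward training pass: scale the incoming gradient with a new beta,
    drawn independently of the alpha used in the forward pass."""
    beta = rng.random()
    grad_branch1 = [beta * g for g in grad_output]
    grad_branch2 = [(1.0 - beta) * g for g in grad_output]
    return grad_branch1, grad_branch2
```

At test time both branches are weighted by 0.5, so `residual_forward([1.0, 2.0], [2.0, 4.0], [4.0, 8.0], training=False)` returns `[4.0, 8.0]`.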

Bibtex:

@article{Gastaldi17ShakeShake,
   title = {Shake-Shake regularization},
   author = {Xavier Gastaldi},
   journal = {arXiv preprint arXiv:1705.07485},
   year = 2017,
}

Results on CIFAR-10

The base network is a 26 2x32d ResNet (i.e. the network has a depth of 26, 2 residual branches and the first residual block has a width of 32). "Shake" means that all scaling coefficients are overwritten with new random numbers before the pass. "Even" means that all scaling coefficients are set to 0.5 before the pass. "Keep" means that we keep, for the backward pass, the scaling coefficients used during the forward pass. "Batch" means that, for each residual block, we apply the same scaling coefficient for all the images in the mini-batch. "Image" means that, for each residual block, we apply a different scaling coefficient for each image in the mini-batch. The numbers in the Table below represent the average of 3 runs except for the 96d models which were run 5 times.

Forward  Backward  Level  26 2x32d  26 2x64d  26 2x96d  26 2x112d
Even     Even      n/a    4.27      3.76      3.58      -
Even     Shake     Batch  4.44      -         -         -
Shake    Keep      Batch  4.11      -         -         -
Shake    Even      Batch  3.47      3.30      -         -
Shake    Shake     Batch  3.67      3.07      -         -
Even     Shake     Image  4.11      -         -         -
Shake    Keep      Image  4.09      -         -         -
Shake    Even      Image  3.47      3.20      -         -
Shake    Shake     Image  3.55      2.98      2.86      2.82 (1)

Table 1: Error rates (%) on CIFAR-10 (Top 1 of the last epoch)
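
The Forward, Backward and Level settings used in Table 1 can be read as a rule for choosing one scaling coefficient per image in each residual block. The sketch below is a hypothetical plain-Python rendering of the definitions given above the table; the function name and arguments are illustrative and do not match the options of the Lua code, and "Shake" is assumed to draw uniformly from [0, 1).

```python
import random


def sample_coefficients(mode, level, batch_size, forward_coeffs=None, rng=random):
    """Scaling coefficients for one residual block, one value per image.

    mode:  "Shake" (new random numbers), "Even" (all 0.5) or
           "Keep" (reuse the forward-pass coefficients; backward pass only).
    level: "Batch" (one coefficient shared by the whole mini-batch) or
           "Image" (a different coefficient for each image).
    """
    if mode == "Even":
        return [0.5] * batch_size
    if mode == "Keep":
        if forward_coeffs is None:
            raise ValueError("Keep needs the coefficients from the forward pass")
        return list(forward_coeffs)
    if mode == "Shake":
        if level == "Batch":
            return [rng.random()] * batch_size
        if level == "Image":
            return [rng.random() for _ in range(batch_size)]
        raise ValueError("unknown level: %r" % level)
    raise ValueError("unknown mode: %r" % mode)


# "Shake-Shake-Image": independent draws for the forward and backward passes.
alphas = sample_coefficients("Shake", "Image", batch_size=4)
betas = sample_coefficients("Shake", "Image", batch_size=4)

# "Shake-Keep-Batch": the backward pass reuses the single forward coefficient.
alpha_batch = sample_coefficients("Shake", "Batch", batch_size=4)
beta_keep = sample_coefficients("Keep", "Batch", batch_size=4,
                                forward_coeffs=alpha_batch)
```

With "Even" the level is irrelevant, which is why the first row of the table lists it as n/a.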

Other results

CIFAR-100:
29 2x4x64d: 15.85%

Reduced CIFAR-10:
26 2x96d: 17.05% (1)

SVHN:
26 2x96d: 1.4% (1)

Reduced SVHN:
26 2x96d: 12.32% (1)

Usage

  1. Install fb.resnet.torch, optnet and lua-stdlib.
  2. Download Shake-Shake
git clone https://github.com/xgastaldi/shake-shake.git
  3. Copy the elements in the shake-shake folder and paste them in the fb.resnet.torch folder. This will overwrite 5 files (main.lua, train.lua, opts.lua, checkpoints.lua and models/init.lua) and add 4 new files (models/shakeshake.lua, models/shakeshakeblock.lua, models/mulconstantslices.lua and models/shakeshaketable.lua).
  4. To reproduce CIFAR-10 results (e.g. 26 2x32d "Shake-Shake-Image" ResNet) on 2 GPUs:
CUDA_VISIBLE_DEVICES=0,1 th main.lua -dataset cifar10 -nGPU 2 -batchSize 128 -depth 26 -shareGradInput false -optnet true -nEpochs 1800 -netType shakeshake -lrShape cosine -baseWidth 32 -LR 0.2 -forwardShake true -backwardShake true -shakeImage true

To get comparable results using 1 GPU, please change the batch size and the corresponding learning rate:

CUDA_VISIBLE_DEVICES=0 th main.lua -dataset cifar10 -nGPU 1 -batchSize 64 -depth 26 -shareGradInput false -optnet true -nEpochs 1800 -netType shakeshake -lrShape cosine -baseWidth 32 -LR 0.1 -forwardShake true -backwardShake true -shakeImage true

A 26 2x96d "Shake-Shake-Image" ResNet can be trained on 2 GPUs using:

CUDA_VISIBLE_DEVICES=0,1 th main.lua -dataset cifar10 -nGPU 2 -batchSize 128 -depth 26 -shareGradInput false -optnet true -nEpochs 1800 -netType shakeshake -lrShape cosine -baseWidth 96 -LR 0.2 -forwardShake true -backwardShake true -shakeImage true
  5. To reproduce CIFAR-100 results (e.g. 29 2x4x64d "Shake-Even-Image" ResNeXt) on 2 GPUs:
CUDA_VISIBLE_DEVICES=0,1 th main.lua -dataset cifar100 -depth 29 -baseWidth 64 -groups 4 -weightDecay 5e-4 -batchSize 32 -netType shakeshake -nGPU 2 -LR 0.025 -nThreads 8 -shareGradInput true -nEpochs 1800 -lrShape cosine -forwardShake true -backwardShake false -shakeImage true
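
The two 26 2x32d CIFAR-10 commands above halve the batch size and the learning rate together (128 with 0.2 on 2 GPUs, 64 with 0.1 on 1 GPU). A minimal sketch of that proportional adjustment is shown below. The helper name is hypothetical, and the proportional rule is an assumption inferred from those two commands only; the CIFAR-100 command uses its own settings and does not follow it.

```python
def scale_learning_rate(ref_lr, ref_batch_size, new_batch_size):
    """Scale the learning rate in proportion to the batch size.

    Assumption: a simple proportional rule, consistent with the two
    CIFAR-10 commands (-batchSize 128 -LR 0.2 and -batchSize 64 -LR 0.1).
    """
    return ref_lr * new_batch_size / ref_batch_size


# Going from the 2-GPU setting to the 1-GPU setting.
one_gpu_lr = scale_learning_rate(ref_lr=0.2, ref_batch_size=128, new_batch_size=64)
```

Here `one_gpu_lr` is 0.1, the value passed as -LR in the 1-GPU command.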

Note

Changes made to fb.resnet.torch files:

main.lua
Ln 17, 54-59, 81-100: Adds a log

train.lua
Ln 36-38, 58-60, 206-213: Adds the cosine learning rate function
Ln 88-89: Adds the learning rate to the elements printed on screen

opts.lua
Ln 21-64: Adds Shake-Shake options

checkpoints.lua
Ln 15-16: Adds require 'models/shakeshakeblock', require 'models/shakeshaketable' and require 'std'
Ln 60-61: Avoids using the fb.resnet.torch deepcopy (it doesn't seem to be compatible with the BN in shakeshakeblock) and replaces it with the deepcopy from stdlib
Ln 67-86: Saves only the last model

models/init.lua
Ln 91-92: Adds require 'models/mulconstantslices', require 'models/shakeshakeblock' and require 'models/shakeshaketable'
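
For readers unfamiliar with the cosine learning rate added to train.lua (selected with -lrShape cosine in the commands above), the sketch below shows the usual cosine annealing schedule without restarts. It is written in Python for illustration and assumes the standard half-cosine form; the exact expression used in train.lua may differ in details such as whether the epoch index starts at 0 or 1.

```python
import math


def cosine_learning_rate(base_lr, epoch, n_epochs):
    """Half-cosine decay: base_lr at epoch 0, falling to 0 at epoch n_epochs.

    Assumed form: 0.5 * base_lr * (1 + cos(pi * epoch / n_epochs)).
    """
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * epoch / n_epochs))


# Learning rate at the start, middle and end of an 1800-epoch run with -LR 0.2.
schedule = [cosine_learning_rate(0.2, e, 1800) for e in (0, 900, 1800)]
```

Under this assumed form the rate starts at 0.2, passes through 0.1 halfway through training and reaches 0 at the last epoch.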

The main model is in shakeshake.lua. The residual block model is in shakeshakeblock.lua. mulconstantslices.lua is an extension of nn.MulConstant that multiplies each image slice of a mini-batch tensor by the corresponding element of a vector. shakeshaketable.lua contains the method used for CIFAR-100, since the ResNeXt code uses a table implementation instead of a module version.
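
The slice-wise multiplication performed by mulconstantslices.lua can be pictured as follows. This is a plain-Python illustration of the idea only (one coefficient per image, applied to every value of that image); the function names are hypothetical and nested lists stand in for Torch tensors.

```python
def scale_nested(values, coeff):
    """Multiply every number in an arbitrarily nested list by coeff."""
    if isinstance(values, list):
        return [scale_nested(v, coeff) for v in values]
    return coeff * values


def mul_constant_slices(batch, coefficients):
    """Scale image i of the mini-batch by coefficients[i].

    batch:        list of images, each a (possibly nested) list of floats.
    coefficients: one scaling coefficient per image.
    """
    if len(batch) != len(coefficients):
        raise ValueError("need exactly one coefficient per image")
    return [scale_nested(image, c) for image, c in zip(batch, coefficients)]


# Two 1x2 "images", scaled by 0.5 and 2.0 respectively.
scaled = mul_constant_slices([[[1.0, 2.0]], [[3.0, 4.0]]], [0.5, 2.0])
```

In this example `scaled` is `[[[0.5, 1.0]], [[6.0, 8.0]]]`: the first image is halved and the second is doubled.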

Reimplementations

PyTorch
https://github.com/hysts/pytorch_shake_shake

TensorFlow
https://github.com/tensorflow/models/blob/master/research/autoaugment/
https://github.com/tensorflow/tensor2tensor

Contact

xgastaldi.mba2011 at london.edu
Any discussions, suggestions and questions are welcome!

References

(1) Ekin D. Cubuk, Barret Zoph, Dandelion Mane, Vijay Vasudevan, and Quoc V. Le. AutoAugment: Learning Augmentation Policies from Data. In arXiv:1805.09501, May 2018.
