DiffStride: Learning strides in convolutional neural networks

Overview

DiffStride is a pooling layer with learnable strides. Strided convolutions, average pooling and max-pooling require cross-validating the stride value at each layer. DiffStride instead can be initialized with an arbitrary stride at each layer (e.g. (2, 2)), and during training its strides are optimized for the task at hand.
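To illustrate how a stride can become a trainable parameter, here is a minimal 1D sketch in plain Python. It assumes, as we understand the paper, that downsampling keeps the low frequencies of the DFT up to a cutoff of n / (2 * stride) bins, and that the hard cutoff is softened by a linear ramp of width `ramp`, so that the mask varies continuously with the stride. The function `soft_frequency_mask` and its parameters are hypothetical and not part of the diffstride API. The actual layer works on 2D feature maps and its mask and cropping may differ.

```python
def soft_frequency_mask(n_bins, stride, ramp=1.0):
    """Hypothetical soft low-pass mask over the DFT bins of a length-n_bins signal.

    Bins whose distance from DC is at most n_bins / (2 * stride) get weight 1, the
    weight then decays linearly to 0 over `ramp` bins, and higher bins get weight 0.
    Inside the ramp the weights are linear in 1 / stride, so they have a nonzero
    derivative with respect to the stride.
    """
    if stride < 1.0 or ramp <= 0.0:
        raise ValueError("expected stride >= 1 and ramp > 0")
    cutoff = n_bins / (2.0 * stride)
    mask = []
    for k in range(n_bins):
        # For k > 0, bin n_bins - k holds frequency -k, so the distance
        # from DC is the smaller of the two indices.
        dist = min(k, n_bins - k)
        weight = (ramp + cutoff - dist) / ramp
        mask.append(min(max(weight, 0.0), 1.0))
    return mask


if __name__ == "__main__":
    # The kept bandwidth (sum of weights) shrinks smoothly as the stride grows.
    for stride in (2.0, 2.5, 3.0):
        print(f"stride={stride}: kept bandwidth={sum(soft_frequency_mask(16, stride)):.2f}")
```

With a hard cutoff, a stride of 2.5 would keep exactly the same bins as some integer stride. With the ramp, a fractional stride partially keeps the boundary bins, which is what makes a gradient-based update of the stride possible in this sketch.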

We describe DiffStride in our ICLR 2022 paper Learning Strides in Convolutional Neural Networks. Compared to the experiments described in the paper, this implementation uses a Pre-Act ResNet and trains with Mixup.

Installation

To install the diffstride library, first clone this repository:

git clone https://github.com/google-research/diffstride.git

Then cd into the repository root and run:

pip install -e .

Example training

To train an example model on CIFAR-10 and write the results as TensorBoard summaries:

python3 -m diffstride.examples.main \
  --gin_config=cifar10.gin \
  --gin_bindings="train.workdir = '/tmp/exp/diffstride/resnet18/'"

Using custom parameters

This implementation uses Gin to parametrize the model, the data processing and the training loop. To use custom parameters, edit examples/cifar10.gin.

For example, to train with fixed Spectral Pooling on CIFAR-100:

data.load_datasets:
  name = 'cifar100'

resnet.Resnet:
  pooling_cls = @pooling.FixedSpectralPooling

Or to train with strided convolutions and without Mixup:

data.load_datasets:
  mixup_alpha = 0.0

resnet.Resnet:
  pooling_cls = None

Results

The current implementation gives the following test accuracies on CIFAR-10 and CIFAR-100, averaged over three runs. To show the robustness of DiffStride to stride initialization, we train both with the standard ResNet strides (resnet.resnet18.strides = '1, 1, 2, 2, 2') and with a 'poor' choice of strides (resnet.resnet18.strides = '1, 1, 3, 2, 3'). Unlike strided convolutions and fixed Spectral Pooling, whose accuracy drops with the poor strides, DiffStride is not affected by the stride initialization.
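The effect of the 'poor' initialization on feature-map size can be checked with a few lines of arithmetic. This sketch assumes 32×32 CIFAR inputs, that each of the five entries in resnet.resnet18.strides downsamples one stage, and that each stage rounds up (ceil(size / stride)), as TensorFlow's 'SAME' padding does for strided convolutions. The helper names are hypothetical. For DiffStride these values are only the initialization, since the strides are learned during training.

```python
import math


def parse_strides(spec):
    """Parse a stride string such as '1, 1, 2, 2, 2' into a list of ints."""
    return [int(s) for s in spec.split(",")]


def stage_resolutions(input_size, strides):
    """Spatial size after each stage, assuming ceil(size / stride) per stage."""
    sizes = []
    size = input_size
    for s in strides:
        size = -(-size // s)  # integer ceil division
        sizes.append(size)
    return sizes


def total_downsampling(strides):
    """Overall downsampling factor if strides divided the size exactly."""
    return math.prod(strides)


if __name__ == "__main__":
    for spec in ("1, 1, 2, 2, 2", "1, 1, 3, 2, 3"):
        strides = parse_strides(spec)
        print(spec, stage_resolutions(32, strides), total_downsampling(strides))
```

Under these assumptions, the poor initialization shrinks the final feature map from 4×4 to 2×2 (a total downsampling factor of 18 instead of 8), which is the kind of mismatch a fixed-stride model cannot recover from.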

CIFAR-10

Pooling                          Test Accuracy (%)           Test Accuracy (%)
                                 strides = (1, 1, 2, 2, 2)   strides = (1, 1, 3, 2, 3)
Strided Convolution (Baseline)   91.06 ± 0.04                89.21 ± 0.27
Spectral Pooling                 93.49 ± 0.05                92.00 ± 0.08
DiffStride                       94.20 ± 0.06                94.19 ± 0.15

CIFAR-100

Pooling                          Test Accuracy (%)           Test Accuracy (%)
                                 strides = (1, 1, 2, 2, 2)   strides = (1, 1, 3, 2, 3)
Strided Convolution (Baseline)   65.75 ± 0.39                60.82 ± 0.42
Spectral Pooling                 72.86 ± 0.23                67.74 ± 0.43
DiffStride                       76.08 ± 0.23                76.09 ± 0.06
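To quantify the robustness claim, the sketch below computes, for each method, the drop in mean test accuracy when switching from the standard strides to the 'poor' strides. It uses only the mean values from the CIFAR-10 and CIFAR-100 tables and ignores the reported ± spreads. The dictionary layout and function name are illustrative.

```python
# Mean test accuracies (%) from the tables: (standard strides, poor strides).
RESULTS = {
    "CIFAR-10": {
        "Strided Convolution": (91.06, 89.21),
        "Spectral Pooling": (93.49, 92.00),
        "DiffStride": (94.20, 94.19),
    },
    "CIFAR-100": {
        "Strided Convolution": (65.75, 60.82),
        "Spectral Pooling": (72.86, 67.74),
        "DiffStride": (76.08, 76.09),
    },
}


def accuracy_drops(results):
    """Accuracy lost (in points) when using the poor strides instead of the standard ones."""
    return {
        dataset: {method: round(std - poor, 2) for method, (std, poor) in methods.items()}
        for dataset, methods in results.items()
    }


if __name__ == "__main__":
    for dataset, drops in accuracy_drops(RESULTS).items():
        print(dataset, drops)
```

This gives drops of 1.85, 1.49 and 0.01 points on CIFAR-10 and 4.93, 5.12 and -0.01 points on CIFAR-100 for strided convolutions, Spectral Pooling and DiffStride, respectively.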

CPU/GPU Warning

We rely on the TensorFlow FFT implementation, which requires the input data to be in the channels_first format. This is not the usual data format of most datasets (including CIFAR), and running with channels_first also prevents using convolutions on CPU. Therefore, although we support the channels_last data format for CPU compatibility, we encourage users to run with the channels_first data format on GPU.
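The two layouts differ only in axis order: channels_last is [batch, height, width, channels] and channels_first is [batch, channels, height, width]. TensorFlow's tf.signal.fft2d transforms the two innermost axes, so only channels_first puts height and width there. In TensorFlow the conversion is tf.transpose(x, perm=[0, 3, 1, 2]). The dependency-free sketch below spells out the same permutation on nested lists. The function name is hypothetical, for illustration only.

```python
def nhwc_to_nchw(batch):
    """Reorder a batch stored as batch[n][h][w][c] into out[n][c][h][w]."""
    out = []
    for image in batch:
        height, width, channels = len(image), len(image[0]), len(image[0][0])
        # Move the channel axis in front of the spatial axes.
        out.append([
            [[image[h][w][c] for w in range(width)] for h in range(height)]
            for c in range(channels)
        ])
    return out


if __name__ == "__main__":
    # One 2x3 image with 2 channels; each value encodes its (h, w, c) position.
    batch = [[[[100 * h + 10 * w + c for c in range(2)] for w in range(3)] for h in range(2)]]
    nchw = nhwc_to_nchw(batch)
    # Sizes of the channel, height and width axes after the transpose.
    print(len(nchw[0]), len(nchw[0][0]), len(nchw[0][0][0]))
```

After the permutation, a 2D FFT over the last two axes of each [channels, height, width] image transforms each channel's spatial grid independently, which is what the spectral layers need.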

Reference

If you use this repository, please consider citing:

@article{riad2022diffstride,
  title={Learning Strides in Convolutional Neural Networks},
  author={Riad, Rachid and Teboul, Olivier and Grangier, David and Zeghidour, Neil},
  journal={ICLR},
  year={2022}
}

Disclaimer

This is not an official Google product.

Owner
Google Research