A collection of differentiable SVD methods and also the official implementation of the ICCV21 paper "Why Approximate Matrix Square Root Outperforms Accurate SVD in Global Covariance Pooling?"

Overview

Differentiable SVD

Introduction

This repository contains:

  1. The official Pytorch implementation of ICCV21 paper Why Approximate Matrix Square Root Outperforms Accurate SVD in Global Covariance Pooling?
  2. A collection of differentiable SVD methods utilized in our paper.

You can also find the presentation of our work via the slides and via the poster.

About the paper

In this paper, we investigate the reason behind why approximate matrix square root calculated via Newton-Schulz iteration outperform the accurate ones computed by SVD from the perspectives of data precision and gradient smoothness. Various remedies for computing smooth SVD gradients are investigated. We also propose a new spectral meta-layer that uses SVD in the forward pass, and Pad'e approximants in the backward propagation to compute the gradients. The results of the so-called SVD-Pad'e achieve state-of-the-art results on ImageNet and FGVC datasets.

Differentiable SVD Methods

As the backward algorithm of SVD is prone to have numerical instability, we implement a variety of end-to-end SVD methods by manipulating the backward algortihms in this repository. They include:

  • SVD-Pad'e: use Pad'e approximants to closely approximate the gradient. It is proposed in our ICCV21 paper.
  • SVD-Taylor: use Taylor polynomial to approximate the smooth gradient. It is proposed in our ICCV21 paper and the TPAMI journal.
  • SVD-PI: use Power Iteration (PI) to approximate the gradients. It is proposed in the NeurIPS19 paper.
  • SVD-Newton: use the gradient of the Newton-Schulz iteration.
  • SVD-Trunc: set a upper limit of the gradient and apply truncation.
  • SVD-TopN: select the Top-N eigenvalues and abandon the rest.
  • SVD-Original: ordinary SVD with gradient overflow check.

In the task of global covaraince pooling, the SVD-Pad'e achieves the best performances. You are free to try other methods in your research.

Implementation and Usage

The codes is modifed on the basis of iSQRT-COV.

See the requirements.txt for the specific required packages.

To train AlexNet on ImageNet, choose a spectral meta-layer in the script and run:

CUDA_VISIBLE_DEVICES=0,1 bash train_alexnet.sh

The pre-trained models of ResNet-50 with SVD-Pad'e is available via Google Drive. You can load the state dict by:

model.load_state_dict(torch.load('pade_resnet50.pth.tar'))

Citation

If you think the codes is helpful to your research, please consider citing our paper:

@inproceedings{song2021approximate,
  title={Why Approximate Matrix Square Root Outperforms Accurate SVD in Global Covariance Pooling?},
  author={Song, Yue and Sebe, Nicu and Wang, Wei},
  booktitle={ICCV},
  year={2021}
}

Contact

If you have any questions or suggestions, please feel free to contact me

[email protected]

Owner
YueSong
Ph.D. student in Computer Vision
YueSong
[AI6101] Introduction to AI & AI Ethics is a core course of MSAI, SCSE, NTU, Singapore

[AI6101] Introduction to AI & AI Ethics is a core course of MSAI, SCSE, NTU, Singapore. The repository corresponds to the AI6101 of Semester 1, AY2021-2022, starting from 08/2021. The instructors of

AccSrd 1 Sep 22, 2022
A graphical Semi-automatic annotation tool based on labelImg and Yolov5

💕YOLOV5 semi-automatic annotation tool (Based on labelImg)

EricFang 247 Jan 05, 2023
Pytorch implementation of ICASSP 2022 paper Attention Probe: Vision Transformer Distillation in the Wild

Attention Probe: Vision Transformer Distillation in the Wild Jiahao Wang, Mingdeng Cao, Shuwei Shi, Baoyuan Wu, Yujiu Yang In ICASSP 2022 This code is

IIGROUP 6 Sep 21, 2022
A toolkit for making real world machine learning and data analysis applications in C++

dlib C++ library Dlib is a modern C++ toolkit containing machine learning algorithms and tools for creating complex software in C++ to solve real worl

Davis E. King 11.6k Jan 01, 2023
ByteTrack(Multi-Object Tracking by Associating Every Detection Box)のPythonでのONNX推論サンプル

ByteTrack-ONNX-Sample ByteTrack(Multi-Object Tracking by Associating Every Detection Box)のPythonでのONNX推論サンプルです。 ONNXに変換したモデルも同梱しています。 変換自体を試したい方はByteT

KazuhitoTakahashi 16 Oct 26, 2022
Array Camera Ptychography

Array Camera Ptychography This repository provides the code for the following papers: Schulz, Timothy J., David J. Brady, and Chengyu Wang. "Photon-li

Brady lab in Optical Sciences 1 Nov 15, 2021
Codebase for "ProtoAttend: Attention-Based Prototypical Learning."

Codebase for "ProtoAttend: Attention-Based Prototypical Learning." Authors: Sercan O. Arik and Tomas Pfister Paper: Sercan O. Arik and Tomas Pfister,

47 2 May 17, 2022
PyTorch implementation of Neural Combinatorial Optimization with Reinforcement Learning.

neural-combinatorial-rl-pytorch PyTorch implementation of Neural Combinatorial Optimization with Reinforcement Learning. I have implemented the basic

Patrick E. 454 Jan 06, 2023
Pretty Tensor - Fluent Neural Networks in TensorFlow

Pretty Tensor provides a high level builder API for TensorFlow. It provides thin wrappers on Tensors so that you can easily build multi-layer neural networks.

Google 1.2k Dec 29, 2022
The dataset of tweets pulling from Twitters with keyword: Hydroxychloroquine, location: US, Time: 2020

HCQ_Tweet_Dataset: FREE to Download. Keywords: HCQ, hydroxychloroquine, tweet, twitter, COVID-19 This dataset is associated with the paper "Understand

2 Mar 16, 2022
Safe Model-Based Reinforcement Learning using Robust Control Barrier Functions

README Repository containing the code for the paper "Safe Model-Based Reinforcement Learning using Robust Control Barrier Functions". Specifically, an

Yousef Emam 13 Nov 24, 2022
Code for the prototype tool in our paper "CoProtector: Protect Open-Source Code against Unauthorized Training Usage with Data Poisoning".

CoProtector Code for the prototype tool in our paper "CoProtector: Protect Open-Source Code against Unauthorized Training Usage with Data Poisoning".

Zhensu Sun 1 Oct 26, 2021
Implementation of Enformer, Deepmind's attention network for predicting gene expression, in Pytorch

Enformer - Pytorch (wip) Implementation of Enformer, Deepmind's attention network for predicting gene expression, in Pytorch. The original tensorflow

Phil Wang 235 Dec 27, 2022
Source code for Fathony, Sahu, Willmott, & Kolter, "Multiplicative Filter Networks", ICLR 2021.

Multiplicative Filter Networks This repository contains a PyTorch MFN implementation and code to perform & reproduce experiments from the ICLR 2021 pa

Bosch Research 66 Jan 04, 2023
JugLab 33 Dec 30, 2022
Experiments and examples converting Transformers to ONNX

Experiments and examples converting Transformers to ONNX This repository containes experiments and examples on converting different Transformers to ON

Philipp Schmid 4 Dec 24, 2022
Knowledge Management for Humans using Machine Learning & Tags

HyperTag HyperTag helps humans intuitively express how they think about their files using tags and machine learning.

Ravn Tech, Inc. 165 Nov 04, 2022
[CVPR 2021] Monocular depth estimation using wavelets for efficiency

Single Image Depth Prediction with Wavelet Decomposition Michaël Ramamonjisoa, Michael Firman, Jamie Watson, Vincent Lepetit and Daniyar Turmukhambeto

Niantic Labs 205 Jan 02, 2023
NeurIPS-2021: Neural Auto-Curricula in Two-Player Zero-Sum Games.

NAC Official PyTorch implementation of NAC from the paper: Neural Auto-Curricula in Two-Player Zero-Sum Games. We release code for: Gradient based ora

Xidong Feng 19 Nov 11, 2022
The materials used in the SaxonJS tutorial presented at Declarative Amsterdam, 2021

SaxonJS-Tutorial-2021, version 1.0.4 Last updated on 4 November, 2021. Table of contents Background Prerequisites Starting a web server Running a Java

Saxonica 11 Oct 23, 2022