Code for Piggyback: Adapting a Single Network to Multiple Tasks by Learning to Mask Weights

Overview

Piggyback: https://arxiv.org/abs/1801.06519

Pretrained masks and backbones are available here: https://uofi.box.com/s/c5kixsvtrghu9yj51yb1oe853ltdfz4q

Datasets in PyTorch format are available here: https://uofi.box.com/s/ixncr3d85guosajywhf7yridszzg5zsq
All rights belong to the respective publishers. The datasets are provided only to aid reproducibility.

The PyTorch-friendly Places365 dataset can be downloaded from http://places2.csail.mit.edu/download.html

Place masks in checkpoints/ and unzipped datasets in data/

VGG-16 ResNet-50 DenseNet-121
CUBS 20.75 18.23 19.24
Stanford Cars 11.78 10.19 10.62
Flowers 6.93 4.77 4.91
WikiArt 29.80 28.57 29.33
Sketch 22.30 19.75 20.05

Note that the numbers in the paper are averaged over multiple runs for each ordering of datasets. These numbers were obtained by evaluating the models on a Titan X (Pascal). Note that numbers on other GPUs might be slightly different (~0.1%) owing to cudnn algorithm selection. https://discuss.pytorch.org/t/slightly-different-results-on-k-40-v-s-titan-x/10064

Requirements:

Python 2.7 or 3.xx
torch==0.2.0.post3
torchvision==0.1.9
torchnet (pip install git+https://github.com/pytorch/[email protected])
tqdm (pip install tqdm)

Run all code from the src/ directory, e.g. ./scripts/run_piggyback_training.sh

Training:

Check out src/scripts/run_piggyback_training.sh.

This script uses the default hyperparams and trains a model as described in the paper. The best performing model on the val set is saved to disk. This saved model includes the real-valued mask weights.

By default, we use the models provided by torchvision as our backbone networks. If you intend to evaluate with the masks provided by us, please use the correct version of torch and torchvision. In case you want to use a different version, but still want to use our masks, then download the pytorch_backbone networks provided in the box link above. Make appropriate changes to your pytorch code to load those backbone models.

Saving trained masks only.

Check out src/scripts/run_packing.sh.

This extracts the binary/ternary masks from the above trained models, and saves them separately.

Eval:

Use the saved masks, apply them to a backbone network and run eval.

By default, our backbone models are those provided with torchvision.
Note that to replicate our results, you have to use the package versions specified above.
Newer package versions might have different weights for the backbones, and the provided masks won't work.

cd src  # Run everything from src/

CUDA_VISIBLE_DEVICES=0 python pack.py --mode eval --dataset flowers \
  --arch vgg16 \
  --maskloc ../checkpoints/vgg16_binary.pt
Owner
Arun Mallya
NVIDIA Research
Arun Mallya
A Simple Long-Tailed Rocognition Baseline via Vision-Language Model

BALLAD This is the official code repository for A Simple Long-Tailed Rocognition Baseline via Vision-Language Model. Requirements Python3 Pytorch(1.7.

Teli Ma 4 Jan 20, 2022
Full Stack Deep Learning Labs

Full Stack Deep Learning Labs Welcome! Project developed during lab sessions of the Full Stack Deep Learning Bootcamp. We will build a handwriting rec

Full Stack Deep Learning 1.2k Dec 31, 2022
Stacked Recurrent Hourglass Network for Stereo Matching

SRH-Net: Stacked Recurrent Hourglass Introduction This repository is supplementary material of our RA-L submission, which helps reviewers to understan

28 Jan 03, 2023
Predicting a person's gender based on their weight and height

Logistic Regression Advanced Case Study Gender Classification: Predicting a person's gender based on their weight and height 1. Introduction We turn o

1 Feb 01, 2022
Transfer Learning Remote Sensing

Transfer_Learning_Remote_Sensing Simulation R codes for data generation and visualizations are in the folder simulation. Experiment: California Housin

2 Jun 21, 2022
Bringing Computer Vision and Flutter together , to build an awesome app !!

Bringing Computer Vision and Flutter together , to build an awesome app !! Explore the Directories Flutter · Machine Learning Table of Contents About

Padmanabha Banerjee 14 Apr 07, 2022
ML-Decoder: Scalable and Versatile Classification Head

ML-Decoder: Scalable and Versatile Classification Head Paper Official PyTorch Implementation Tal Ridnik, Gilad Sharir, Avi Ben-Cohen, Emanuel Ben-Baru

189 Jan 04, 2023
This is an official repository of CLGo: Learning to Predict 3D Lane Shape and Camera Pose from a Single Image via Geometry Constraints

CLGo This is an official repository of CLGo: Learning to Predict 3D Lane Shape and Camera Pose from a Single Image via Geometry Constraints An earlier

刘芮金 32 Dec 20, 2022
Line-level Handwritten Text Recognition (HTR) system implemented with TensorFlow.

Line-level Handwritten Text Recognition with TensorFlow This model is an extended version of the Simple HTR system implemented by @Harald Scheidl and

Hoàng Tùng Lâm (Linus) 72 May 07, 2022
MINERVA: An out-of-the-box GUI tool for offline deep reinforcement learning

MINERVA is an out-of-the-box GUI tool for offline deep reinforcement learning, designed for everyone including non-programmers to do reinforcement learning as a tool.

Takuma Seno 80 Nov 06, 2022
Code repo for "FASA: Feature Augmentation and Sampling Adaptation for Long-Tailed Instance Segmentation" (ICCV 2021)

FASA: Feature Augmentation and Sampling Adaptation for Long-Tailed Instance Segmentation (ICCV 2021) This repository contains the implementation of th

Yuhang Zang 21 Dec 17, 2022
simple_pytorch_example project is a toy example of a python script that instantiates and trains a PyTorch neural network on the FashionMNIST dataset

simple_pytorch_example project is a toy example of a python script that instantiates and trains a PyTorch neural network on the FashionMNIST dataset

Ramón Casero 1 Jan 07, 2022
AlgoVision - A Framework for Differentiable Algorithms and Algorithmic Supervision

NeurIPS 2021 Paper "Learning with Algorithmic Supervision via Continuous Relaxations"

Felix Petersen 76 Jan 01, 2023
Official implement of "CAT: Cross Attention in Vision Transformer".

CAT: Cross Attention in Vision Transformer This is official implement of "CAT: Cross Attention in Vision Transformer". Abstract Since Transformer has

100 Dec 15, 2022
This repository contains datasets and baselines for benchmarking Chinese text recognition.

Benchmarking-Chinese-Text-Recognition This repository contains datasets and baselines for benchmarking Chinese text recognition. Please see the corres

FudanVI Lab 254 Dec 30, 2022
The repository offers the official implementation of our paper in PyTorch.

Cloth Interactive Transformer (CIT) Cloth Interactive Transformer for Virtual Try-On Bin Ren1, Hao Tang1, Fanyang Meng2, Runwei Ding3, Ling Shao4, Phi

Bingoren 49 Dec 01, 2022
验证码识别 深度学习 tensorflow 神经网络

captcha_tf2 验证码识别 深度学习 tensorflow 神经网络 使用卷积神经网络,对字符,数字类型验证码进行识别,tensorflow使用2.0以上 目前项目还在更新中,诸多bug,欢迎提出issue和PR, 希望和你一起共同完善项目。 实例demo 训练过程 优化器选择: Adam

5 Apr 28, 2022
A Topic Modeling toolbox

Topik A Topic Modeling toolbox. Introduction The aim of topik is to provide a full suite and high-level interface for anyone interested in applying to

Anaconda, Inc. (formerly Continuum Analytics, Inc.) 93 Dec 01, 2022
Local Similarity Pattern and Cost Self-Reassembling for Deep Stereo Matching Networks

Local Similarity Pattern and Cost Self-Reassembling for Deep Stereo Matching Networks Contributions A novel pairwise feature LSP to extract structural

31 Dec 06, 2022
Code for the published paper : Learning to recognize rare traffic sign

Improving traffic sign recognition by active search This repo contains code for the paper : "Learning to recognise rare traffic signs" How to use this

samsja 4 Jan 05, 2023