[CVPR 2021] "The Lottery Tickets Hypothesis for Supervised and Self-supervised Pre-training in Computer Vision Models" Tianlong Chen, Jonathan Frankle, Shiyu Chang, Sijia Liu, Yang Zhang, Michael Carbin, Zhangyang Wang


License: MIT

Code for the paper "The Lottery Tickets Hypothesis for Supervised and Self-supervised Pre-training in Computer Vision Models" [CVPR 2021].


Overview

Can we aggressively trim down the complexity of pre-trained models, without damaging their downstream transferability?

Transfer Learning for Winning Tickets from Supervised and Self-supervised Pre-training

Downstream classification tasks.

Downstream detection and segmentation tasks.

Properties of Pre-training Tickets

Reproduce

Preliminary

Required environment:

  • PyTorch >= 1.5.0
  • torchvision

Pre-trained Models

Pre-trained models are provided here.

imagenet_weight.pt # standard torchvision model (supervised ImageNet pre-training)

moco.pt # pretrained MoCo v2 model (contains encoder_q only)

moco_v2_800ep_pretrain.pth.tar # pretrained MoCo v2 model (contains both encoder_q and encoder_k)

simclr_weight.pt # pretrained SimCLR weights
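The two MoCo files differ in whether the momentum (key) encoder is included. Below is a minimal sketch of reducing a full MoCo checkpoint to query-encoder weights, assuming the key layout of the official MoCo repository (entries named like `module.encoder_q.conv1.weight`, plus `module.encoder_k.*` and the `module.queue` buffer). Whether `moco.pt` was produced exactly this way is an assumption. `extract_encoder_q` is a hypothetical helper, and strings stand in for tensors so the example runs without PyTorch.

```python
def extract_encoder_q(state_dict, prefix="module.encoder_q.", drop_fc=True):
    """Keep only query-encoder entries and strip their prefix.

    Assumes the MoCo v2 key layout (e.g. 'module.encoder_q.conv1.weight').
    With drop_fc=True the projection head ('fc.*') is discarded, which is
    the usual choice before transferring the backbone.
    """
    encoder_q = {}
    for key, value in state_dict.items():
        if not key.startswith(prefix):
            continue  # encoder_k, the queue, etc. are not needed downstream
        name = key[len(prefix):]
        if drop_fc and name.startswith("fc."):
            continue
        encoder_q[name] = value
    return encoder_q


# With PyTorch the dict would come from something like
# torch.load("moco_v2_800ep_pretrain.pth.tar", map_location="cpu")["state_dict"].
ckpt = {
    "module.encoder_q.conv1.weight": "q_conv1",
    "module.encoder_q.fc.0.weight": "q_mlp",
    "module.encoder_k.conv1.weight": "k_conv1",
    "module.queue": "queue",
}
print(extract_encoder_q(ckpt))  # {'conv1.weight': 'q_conv1'}
```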

Finding Task-Specific Tickets

Remark: the procedures below apply to both pre-training tasks and downstream tasks.

Iterative Magnitude Pruning
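IMP alternates training with pruning away the smallest-magnitude weights that are still alive. The sketch below shows a single pruning step at the `--prun_percent 0.2` rate used in the SimCLR command. It assumes global unstructured pruning (all layers ranked together) over plain Python lists; whether the scripts rank weights globally or per layer is not stated here, and `magnitude_prune` is a hypothetical name.

```python
def magnitude_prune(weights, mask, prune_percent=0.2):
    """One IMP pruning step: zero out the smallest-magnitude surviving weights.

    weights: dict layer name -> list of floats
    mask:    dict layer name -> list of 0/1 (1 = weight still in the ticket)
    Ranking is global (all layers pooled); this is an assumption of the sketch.
    """
    survivors = [
        (abs(w), name, i)
        for name, layer in weights.items()
        for i, w in enumerate(layer)
        if mask[name][i]
    ]
    survivors.sort(key=lambda item: item[0])
    n_prune = int(len(survivors) * prune_percent)
    new_mask = {name: list(m) for name, m in mask.items()}
    for _, name, i in survivors[:n_prune]:
        new_mask[name][i] = 0
    return new_mask


weights = {"conv": [0.5, -0.1, 0.3, 0.05, -0.9], "fc": [0.2, -0.4, 0.01, 0.7, 0.6]}
mask = {name: [1] * len(layer) for name, layer in weights.items()}
mask = magnitude_prune(weights, mask)
print(mask)  # {'conv': [1, 1, 1, 0, 1], 'fc': [1, 1, 0, 1, 1]}
```

Retraining between steps and the reset of surviving weights are omitted. In PyTorch, `torch.nn.utils.prune.global_unstructured` with `pruning_method=prune.L1Unstructured` performs a comparable global magnitude ranking.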

SimCLR task
cd SimCLR
# --prun_epoch 10 with --epochs 180: pruning runs for 1 + 180/10 iterations
python -u main.py \
    [experiment name] \
    --gpu 0,1,2,3 \
    --epochs 180 \
    --prun_epoch 10 \
    --prun_percent 0.2 \
    --lr 1e-4 \
    --arch resnet50 \
    --batch_size 256 \
    --data [data directory] \
    --sim_model [pretrained_simclr_model] \
    --save_dir simclr_imp
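The iteration count in the comment above and the per-step rate together determine how sparse the final ticket is. The small calculation below assumes that every one of the 1 + 180/10 = 19 iterations ends with one pruning step removing 20% of the weights still alive, so the remaining fraction after k steps is 0.8^k. The same count of 19 appears as `--states` and `--pruning_times` in the commands below. `imp_schedule` is a hypothetical helper.

```python
def imp_schedule(epochs=180, prun_epoch=10, prun_percent=0.2):
    """Number of pruning iterations and remaining weight fraction after each.

    The iteration count follows the comment in the SimCLR command
    (1 + epochs / prun_epoch). The density formula assumes each iteration
    removes prun_percent of the weights still alive, not of the original total.
    """
    iterations = 1 + epochs // prun_epoch
    remaining = [(1 - prun_percent) ** k for k in range(1, iterations + 1)]
    return iterations, remaining


iterations, remaining = imp_schedule()
print(iterations)              # 19
print(f"{remaining[-1]:.4f}")  # 0.0144
```

Under that assumption, about 1.4% of the prunable weights remain after the last iteration (roughly 98.6% sparsity).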
MoCo task
cd MoCo
# --retrain_epoch 10 with --epochs 180: pruning runs for 1 + 180/10 iterations
CUDA_VISIBLE_DEVICES=0,1,2,3 python -u main_moco_imp.py \
    [dataset directory] \
    --pretrained_path [pretrained_moco_model] \
    -a resnet50 \
    --batch-size 256 \
    --dist-url 'tcp://127.0.0.1:5234' \
    --multiprocessing-distributed \
    --world-size 1 \
    --rank 0 \
    --mlp \
    --moco-t 0.2 \
    --aug-plus \
    --cos \
    --epochs 180 \
    --retrain_epoch 10 \
    --save_dir moco_imp
Classification task on ImageNet
# --states 19: number of iterative pruning rounds
CUDA_VISIBLE_DEVICES=0,1,2,3 python -u main_imp_imagenet.py \
    [dataset directory] \
    -a resnet50 \
    --epochs 10 \
    -b 256 \
    --lr 1e-4 \
    --states 19 \
    --save_dir imagenet_imp
Classification task on Visda2017
# --prune_type: lt or pt_trans
# --pre_weight: pretrained weight if prune_type is pt_trans, otherwise None
# --states 19: number of iterative pruning rounds
CUDA_VISIBLE_DEVICES=0,1,2,3 python -u main_imp_visda.py \
    [dataset directory] \
    -a resnet50 \
    --epochs 20 \
    -b 256 \
    --lr 0.001 \
    --prune_type lt \
    --pre_weight [pretrained weight] \
    --states 19 \
    --save_dir visda_imp
Classification task on small datasets
# --dataset: cifar10, cifar100, svhn, or fmnist
# --prune_type: one of lt, pt, rewind_lt, pt_trans
# Optional flags, appended to the command as needed:
#   --pretrained [pretrained weight]       if prune_type is pt_trans
#   --random_prune                         if using random pruning
#   --rewind_epoch [rewind weight epoch]   if prune_type is rewind_lt
CUDA_VISIBLE_DEVICES=0 python -u main_imp_downstream.py \
    --data [dataset directory] \
    --dataset [dataset name] \
    --arch resnet50 \
    --pruning_times 19 \
    --prune_type [prune type] \
    --save_dir imp_downstream
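In lottery-ticket style IMP, the weights that survive a pruning step are reset ("rewound") to an earlier saved state before retraining, while pruned weights stay at zero. The minimal sketch below assumes that `lt` rewinds to the weights at the start of training and that `rewind_lt` rewinds to the weights saved at `--rewind_epoch`. How `pt` and `pt_trans` reset weights is not described in this README, so the sketch does not model them. `rewind` and `pick_rewind_state` are hypothetical names working on plain lists.

```python
def rewind(rewind_weights, mask):
    """Reset surviving weights to a saved state; pruned weights stay at zero."""
    return {
        name: [w if m else 0.0 for w, m in zip(layer, mask[name])]
        for name, layer in rewind_weights.items()
    }


def pick_rewind_state(prune_type, checkpoints, rewind_epoch=None):
    """Choose the saved weights to rewind to (hypothetical helper).

    checkpoints: dict epoch -> weights saved during the first, dense training run.
    """
    if prune_type == "lt":
        return checkpoints[0]  # weights at the start of training
    if prune_type == "rewind_lt":
        return checkpoints[rewind_epoch]  # weights from early in training
    raise ValueError(f"not modeled in this sketch: {prune_type}")


checkpoints = {0: {"fc": [0.1, -0.2, 0.3]}, 3: {"fc": [0.15, -0.25, 0.35]}}
mask = {"fc": [1, 0, 1]}
print(rewind(pick_rewind_state("rewind_lt", checkpoints, rewind_epoch=3), mask))
# {'fc': [0.15, 0.0, 0.35]}
```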

Transfer to Downstream Tasks

Small datasets (e.g., CIFAR-10, CIFAR-100, SVHN, Fashion-MNIST):
# --dataset: cifar10, cifar100, svhn, or fmnist
# --dict_key: key of the state dict inside the pretrained file (None means load all)
# --reverse_mask: include only if you want to reverse the ticket mask
CUDA_VISIBLE_DEVICES=0 python -u main_eval_downstream.py \
    --data [dataset directory] \
    --dataset [dataset name] \
    --arch resnet50 \
    --save_dir [save directory] \
    --pretrained [init weight] \
    --dict_key state_dict \
    --mask_dir [mask for ticket] \
    --reverse_mask
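`--mask_dir` supplies a ticket's binary mask, and `--reverse_mask` flips it. The minimal sketch below covers both, assuming the reversed mask is the complement `1 - mask` (keeping exactly the weights the ticket removed) and that layers without a mask entry are left dense. `apply_ticket_mask` is a hypothetical name operating on plain lists rather than tensors.

```python
def apply_ticket_mask(weights, mask, reverse=False):
    """Zero out weights outside the ticket.

    With reverse=True the complement (1 - mask) is used, so only the weights
    the ticket pruned are kept; this reading of --reverse_mask is an assumption.
    Layers without a mask entry are copied unchanged.
    """
    masked = {}
    for name, layer in weights.items():
        if name not in mask:
            masked[name] = list(layer)
            continue
        keep = [1 - m for m in mask[name]] if reverse else mask[name]
        masked[name] = [w if k else 0.0 for w, k in zip(layer, keep)]
    return masked


weights = {"conv": [0.4, -0.3, 0.2, 0.1], "fc": [1.0, 2.0]}
mask = {"conv": [1, 0, 0, 1]}
print(apply_ticket_mask(weights, mask))
# {'conv': [0.4, 0.0, 0.0, 0.1], 'fc': [1.0, 2.0]}
print(apply_ticket_mask(weights, mask, reverse=True))
# {'conv': [0.0, -0.3, 0.2, 0.0], 'fc': [1.0, 2.0]}
```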
Visda2017:
# --dict_key: key of the state dict inside the pretrained file (None means load all)
# --reverse_mask: include only if you want to reverse the ticket mask
CUDA_VISIBLE_DEVICES=0,1,2,3 python -u main_eval_visda.py \
    [data directory] \
    -a resnet50 \
    --epochs 20 \
    -b 256 \
    --lr 0.001 \
    --save_dir [save directory] \
    --pretrained [init weight] \
    --dict_key state_dict \
    --mask_dir [mask for ticket] \
    --reverse_mask
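`--dict_key` selects which entry of the loaded file holds the weights. The minimal sketch of that convention below assumes the checkpoint is a dict and that the command-line string `None` means "use the whole file". `select_state_dict` is a hypothetical helper, not a function from the repository.

```python
def select_state_dict(checkpoint, dict_key):
    """Mimic the --dict_key convention: a key picks a sub-dict, None loads all.

    Treating the string "None" like Python None is an assumption about how
    the command-line value is parsed.
    """
    if dict_key is None or dict_key == "None":
        return checkpoint
    if dict_key not in checkpoint:
        raise KeyError(f"{dict_key!r} not found; available keys: {sorted(checkpoint)}")
    return checkpoint[dict_key]


ckpt = {"state_dict": {"conv1.weight": "w"}, "epoch": 20}
print(select_state_dict(ckpt, "state_dict"))  # {'conv1.weight': 'w'}
print(select_state_dict(ckpt, None) is ckpt)  # True
```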

Detection and Segmentation Experiments

Details of YOLOv4 for detection are collected here.

Details of DeepLabv3+ for segmentation are collected here.

Citation

@article{chen2020lottery,
  title={The Lottery Tickets Hypothesis for Supervised and Self-supervised Pre-training in Computer Vision Models},
  author={Chen, Tianlong and Frankle, Jonathan and Chang, Shiyu and Liu, Sijia and Zhang, Yang and Carbin, Michael and Wang, Zhangyang},
  journal={arXiv preprint arXiv:2012.06908},
  year={2020}
}

Acknowledgement

https://github.com/google-research/simclr

https://github.com/facebookresearch/moco

https://github.com/VainF/DeepLabV3Plus-Pytorch

https://github.com/argusswift/YOLOv4-pytorch

https://github.com/yczhang1017/SSD_resnet_pytorch/tree/master

Owner: VITA, Visual Informatics Group @ University of Texas at Austin