code for paper "Does Unsupervised Architecture Representation Learning Help Neural Architecture Search?"

Overview

Does Unsupervised Architecture Representation Learning Help Neural Architecture Search?

Code for paper:

Does Unsupervised Architecture Representation Learning Help Neural Architecture Search?
Shen Yan, Yu Zheng, Wei Ao, Xiao Zeng, Mi Zhang.
NeurIPS 2020.

arch2vec
Top: The supervision signal for representation learning comes from the accuracies of architectures selected by the search strategies. Bottom (ours): Disentangling architecture representation learning and architecture search through unsupervised pre-training.

The repository is built upon pytorch_geometric, pybnn, nas_benchmarks, bananas.

1. Requirements

  • NVIDIA GPU, Linux, Python3
pip install -r requirements.txt

2. Experiments on NAS-Bench-101

Dataset preparation on NAS-Bench-101

Install nasbench and download nasbench_only108.tfrecord under ./data folder.

python preprocessing/gen_json.py

Data will be saved in ./data/data.json.

Pretraining

bash models/pretraining_nasbench101.sh

The pretrained model will be saved in ./pretrained/dim-16/.

arch2vec extraction

bash run_scripts/extract_arch2vec.sh

The extracted arch2vec will be saved in ./pretrained/dim-16/.

Alternatively, you can download the pretrained arch2vec on NAS-Bench-101.

Run experiments of RL search on NAS-Bench-101

bash run_scripts/run_reinforce_supervised.sh 
bash run_scripts/run_reinforce_arch2vec.sh 

Search results will be saved in ./saved_logs/rl/dim16

Generate json file:

python plot_scripts/plot_reinforce_search_arch2vec.py 

Run experiments of BO search on NAS-Bench-101

bash run_scripts/run_dngo_supervised.sh 
bash run_scripts/run_dngo_arch2vec.sh 

Search results will be saved in ./saved_logs/bo/dim16.

Generate json file:

python plot_scripts/plot_dngo_search_arch2vec.py

Plot NAS comparison curve on NAS-Bench-101:

python plot_scipts/plot_nasbench101_comparison.py

Plot CDF comparison curve on NAS-Bench-101:

Download the search results from search_logs.

python plot_scripts/plot_cdf.py

3. Experiments on NAS-Bench-201

Dataset preparation

Download the NAS-Bench-201-v1_0-e61699.pth under ./data folder.

python preprocessing/nasbench201_json.py

Data corresponding to the three datasets in NAS-Bench-201 will be saved in folder ./data/ as cifar10_valid_converged.json, cifar100.json, ImageNet16_120.json.

Pretraining

bash models/pretraining_nasbench201.sh

The pretrained model will be saved in ./pretrained/dim-16/.

Note that the pretrained model is shared across the 3 datasets in NAS-Bench-201.

arch2vec extraction

bash run_scripts/extract_arch2vec_nasbench201.sh

The extracted arch2vec will be saved in ./pretrained/dim-16/ as cifar10_valid_converged-arch2vec.pt, cifar100-arch2vec.pt and ImageNet16_120-arch2vec.pt.

Alternatively, you can download the pretrained arch2vec on NAS-Bench-201.

Run experiments of RL search on NAS-Bench-201

CIFAR-10: ./run_scripts/run_reinforce_arch2vec_nasbench201_cifar10_valid.sh
CIFAR-100: ./run_scripts/run_reinforce_arch2vec_nasbench201_cifar100.sh
ImageNet-16-120: ./run_scripts/run_reinforce_arch2vec_nasbench201_ImageNet.sh

Run experiments of BO search on NAS-Bench-201

CIFAR-10: ./run_scripts/run_bo_arch2vec_nasbench201_cifar10_valid.sh
CIFAR-100: ./run_scripts/run_bo_arch2vec_nasbench201_cifar100.sh
ImageNet-16-120: ./run_scripts/run_bo_arch2vec_nasbench201_ImageNet.sh

Summarize search result on NAS-Bench-201

python ./plot_scripts/summarize_nasbench201.py

The corresponding table will be printed to the console.

4. Experiments on DARTS Search Space

CIFAR-10 can be automatically downloaded by torchvision, ImageNet needs to be manually downloaded (preferably to a SSD) from http://image-net.org/download.

Random sampling 600,000 isomorphic graphs in DARTS space

python preprocessing/gen_isomorphism_graphs.py

Data will be saved in ./data/data_darts_counter600000.json.

Alternatively, you can download the extracted data_darts_counter600000.json.

Pretraining

bash models/pretraining_darts.sh

The pretrained model is saved in ./pretrained/dim-16/.

arch2vec extraction

bash run_scripts/extract_arch2vec_darts.sh

The extracted arch2vec will be saved in ./pretrained/dim-16/arch2vec-darts.pt.

Alternatively, you can download the pretrained arch2vec on DARTS search space.

Run experiments of RL search on DARTS search space

bash run_scripts/run_reinforce_arch2vec_darts.sh

logs will be saved in ./darts-rl/.

Final search result will be saved in ./saved_logs/rl/dim16.

Run experiments of BO search on DARTS search space

bash run_scripts/run_bo_arch2vec_darts.sh

logs will be saved in ./darts-bo/ .

Final search result will be saved in ./saved_logs/bo/dim16.

Evaluate the learned cell on DARTS Search Space on CIFAR-10

python darts/cnn/train.py --auxiliary --cutout --arch arch2vec_rl --seed 1
python darts/cnn/train.py --auxiliary --cutout --arch arch2vec_bo --seed 1
  • Expected results (RL): 2.60% test error with 3.3M model params.
  • Expected results (BO): 2.48% test error with 3.6M model params.

Transfer learning on ImageNet

python darts/cnn/train_imagenet.py  --arch arch2vec_rl --seed 1 
python darts/cnn/train_imagenet.py  --arch arch2vec_bo --seed 1
  • Expected results (RL): 25.8% test error with 4.8M model params and 533M mult-adds.
  • Expected results (RL): 25.5% test error with 5.2M model params and 580M mult-adds.

Visualize the learned cell

python darts/cnn/visualize.py arch2vec_rl
python darts/cnn/visualize.py arch2vec_bo

5. Analyzing the results

Visualize a sequence of decoded cells from the latent space

Download pretrained supervised embeddings of nasbench101 and nasbench201.

bash plot_scripts/drawfig5-nas101.sh # visualization on nasbench-101
bash plot_scripts/drawfig5-nas201.sh # visualization on nasbench-201
bash plot_scripts/drawfig5-darts.sh  # visualization on darts

The plots will be saved in ./graphvisualization.

Plot distribution of L2 distance by edit distance

Install nas_benchmarks and download nasbench_full.tfrecord under the same directory.

python plot_scripts/distance_comparison_fig3.py

Latent space 2D visualization

bash plot_scripts/drawfig4.sh

the plots will be saved in ./density.

Predictive performance comparison

Download predicted_accuracy under saved_logs/.

python plot_scripts/pearson_plot_fig2.py

Citation

If you find this useful for your work, please consider citing:

@InProceedings{yan2020arch,
  title = {Does Unsupervised Architecture Representation Learning Help Neural Architecture Search?},
  author = {Yan, Shen and Zheng, Yu and Ao, Wei and Zeng, Xiao and Zhang, Mi},
  booktitle = {NeurIPS},
  year = {2020}
}
Repo for EchoVPR: Echo State Networks for Visual Place Recognition

EchoVPR Repo for EchoVPR: Echo State Networks for Visual Place Recognition Currently under development Dirs: data: pre-collected hidden representation

Anil Ozdemir 4 Oct 04, 2022
Large-scale Hyperspectral Image Clustering Using Contrastive Learning, CIKM 21 Workshop

Spectral-spatial contrastive clustering (SSCC) Yaoming Cai, Yan Liu, Zijia Zhang, Zhihua Cai, and Xiaobo Liu, Large-scale Hyperspectral Image Clusteri

Yaoming Cai 4 Nov 02, 2022
Pytorch implementation for Patient Knowledge Distillation for BERT Model Compression

Patient Knowledge Distillation for BERT Model Compression Knowledge distillation for BERT model Installation Run command below to install the environm

Siqi 180 Dec 19, 2022
High level network definitions with pre-trained weights in TensorFlow

TensorNets High level network definitions with pre-trained weights in TensorFlow (tested with 2.1.0 = TF = 1.4.0). Guiding principles Applicability.

Taehoon Lee 1k Dec 13, 2022
This is an official implementation for "DeciWatch: A Simple Baseline for 10x Efficient 2D and 3D Pose Estimation"

DeciWatch: A Simple Baseline for 10× Efficient 2D and 3D Pose Estimation This repo is the official implementation of "DeciWatch: A Simple Baseline for

117 Dec 24, 2022
Quasi-Dense Similarity Learning for Multiple Object Tracking, CVPR 2021 (Oral)

Quasi-Dense Tracking This is the offical implementation of paper Quasi-Dense Similarity Learning for Multiple Object Tracking. We present a trailer th

ETH VIS Research Group 327 Dec 27, 2022
Official implementation of Meta-StyleSpeech and StyleSpeech

Meta-StyleSpeech : Multi-Speaker Adaptive Text-to-Speech Generation Dongchan Min, Dong Bok Lee, Eunho Yang, and Sung Ju Hwang This is an official code

min95 168 Dec 28, 2022
disentanglement_lib is an open-source library for research on learning disentangled representations.

disentanglement_lib disentanglement_lib is an open-source library for research on learning disentangled representation. It supports a variety of diffe

Google Research 1.3k Dec 28, 2022
A deep learning CNN model to identify and classify and check if a person is wearing a mask or not.

Face Mask Detection The Model is designed to check if any human is wearing a mask or not. Dataset Description The Dataset contains a total of 11,792 i

1 Mar 01, 2022
Applying CLIP to Point Cloud Recognition.

PointCLIP: Point Cloud Understanding by CLIP This repository is an official implementation of the paper 'PointCLIP: Point Cloud Understanding by CLIP'

Renrui Zhang 175 Dec 24, 2022
CMUA-Watermark: A Cross-Model Universal Adversarial Watermark for Combating Deepfakes (AAAI2022)

CMUA-Watermark The official code for CMUA-Watermark: A Cross-Model Universal Adversarial Watermark for Combating Deepfakes (AAAI2022) arxiv. It is bas

50 Nov 26, 2022
A scikit-learn-compatible module for estimating prediction intervals.

|Anaconda|_ MAPIE - Model Agnostic Prediction Interval Estimator MAPIE allows you to easily estimate prediction intervals using your favourite sklearn

SimAI 584 Dec 27, 2022
This is an official implementation for "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" on Object Detection and Instance Segmentation.

Swin Transformer for Object Detection This repo contains the supported code and configuration files to reproduce object detection results of Swin Tran

Swin Transformer 1.4k Dec 30, 2022
Code for paper: Group-CAM: Group Score-Weighted Visual Explanations for Deep Convolutional Networks

Group-CAM By Zhang, Qinglong and Rao, Lu and Yang, Yubin [State Key Laboratory for Novel Software Technology at Nanjing University] This repo is the o

zhql 98 Nov 16, 2022
A python library for time-series smoothing and outlier detection in a vectorized way.

tsmoothie A python library for time-series smoothing and outlier detection in a vectorized way. Overview tsmoothie computes, in a fast and efficient w

Marco Cerliani 517 Dec 28, 2022
Repo for parser tensorflow(.pb) and tflite(.tflite)

tfmodel_parser .pb file is the format of tensorflow model .tflite file is the format of tflite model, which usually used in mobile devices before star

1 Dec 23, 2021
Tzer: TVM Implementation of "Coverage-Guided Tensor Compiler Fuzzing with Joint IR-Pass Mutation (OOPSLA'22)“.

Artifact • Reproduce Bugs • Quick Start • Installation • Extend Tzer Coverage-Guided Tensor Compiler Fuzzing with Joint IR-Pass Mutation This is the s

12 Dec 29, 2022
Sign-to-Speech for Sign Language Understanding: A case study of Nigerian Sign Language

Sign-to-Speech for Sign Language Understanding: A case study of Nigerian Sign Language This repository contains the code, model, and deployment config

16 Oct 23, 2022
[ACM MM 2021] Yes, "Attention is All You Need", for Exemplar based Colorization

Transformer for Image Colorization This is an implemention for Yes, "Attention Is All You Need", for Exemplar based Colorization, and the current soft

Wang Yin 30 Dec 07, 2022
Single Red Blood Cell Hydrodynamic Traps Via the Generative Design

Rbc-traps-generative-design - The generative design for single red clood cell hydrodynamic traps using GEFEST framework

Natural Systems Simulation Lab 4 Jun 16, 2022