CATE: Computation-aware Neural Architecture Encoding with Transformers


Code for paper:

CATE: Computation-aware Neural Architecture Encoding with Transformers
Shen Yan, Kaiqiang Song, Fei Liu, Mi Zhang.
ICML 2021 (Long Talk).

Overview of CATE: the model takes pairs of computationally similar architectures as input and is trained to predict masked operators given the pairwise computation information. After pretraining, the Transformer encoder, without the cross-attention blocks, is used to extract architecture encodings for the downstream search.
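The snippet below is a rough illustration of the masked-operator objective described above. It masks a random subset of one architecture's operator tokens and records the originals as prediction targets. The 15% masking rate and the reserved MASK_ID = 0 are assumptions borrowed from BERT-style pretraining, not values taken from this repository. The pairwise cross-attention, which conditions CATE's prediction on the partner architecture, is not modeled here.

```python
import random

# Assumption: token id 0 is reserved for the mask token; real operator ids start at 1.
MASK_ID = 0

def mask_operators(op_ids, mask_prob=0.15, rng=None):
    """Replace a random subset of operator tokens with MASK_ID.

    Returns (masked_sequence, targets): targets[i] holds the original operator id
    at masked positions and None elsewhere, so a loss is computed only where masked.
    """
    rng = rng or random.Random()
    masked, targets = [], []
    for op in op_ids:
        if rng.random() < mask_prob:
            masked.append(MASK_ID)
            targets.append(op)
        else:
            masked.append(op)
            targets.append(None)
    # Force at least one masked position so every sample carries a training signal.
    if op_ids and all(t is None for t in targets):
        i = rng.randrange(len(op_ids))
        masked[i], targets[i] = MASK_ID, op_ids[i]
    return masked, targets
```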

The repository is built upon pybnn and nas-encodings.

Requirements

conda create -n tf python=3.7
source activate tf
cat requirements.txt | xargs -n 1 -L 1 pip install
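The xargs pipeline runs a separate pip install for each line. As a result, one package that fails to build does not stop the others from installing. Where xargs is unavailable (for example on Windows), a rough Python equivalent might look like the sketch below. Comment handling follows pip's requirements-file rule rather than xargs. The helper names are hypothetical.

```python
import re
import subprocess
import sys

# pip treats "#" as a comment only at line start or after whitespace,
# so URLs such as "git+https://...#egg=name" keep their fragment.
_COMMENT = re.compile(r"(?:^|\s)#")

def parse_requirements(text):
    """Return one requirement specifier per non-empty, non-comment line."""
    reqs = []
    for line in text.splitlines():
        line = _COMMENT.split(line, maxsplit=1)[0].strip()
        if line:
            reqs.append(line)
    return reqs

def install_one_by_one(reqs):
    """Run a separate `pip install` per requirement and return the ones that failed."""
    failed = []
    for req in reqs:
        result = subprocess.run([sys.executable, "-m", "pip", "install", req])
        if result.returncode != 0:
            failed.append(req)
    return failed
```

Inside the activated tf environment, call `install_one_by_one(parse_requirements(open("requirements.txt").read()))` and inspect the returned list for packages that need manual attention.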

Experiments on NAS-Bench-101

Dataset preparation on NAS-Bench-101

Install nasbench and download nasbench_only108.tfrecord into the ./data folder.

python preprocessing/gen_json.py

Data will be saved in ./data/nasbench101.json.
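The layout of nasbench101.json is not described here. The sketch below assumes that each entry stores a NAS-Bench-101 cell as an adjacency matrix plus a list of operator labels, which is the form used by the nasbench API's ModelSpec. Under that assumption, a cell could be turned into fixed-size model inputs as follows. The token ids and the tail-padding choice are hypothetical.

```python
# Hypothetical vocabulary: the token ids produced by the repo's preprocessing may differ.
NB101_OPS = ["input", "conv1x1-bn-relu", "conv3x3-bn-relu", "maxpool3x3", "output"]
PAD_ID = 0  # reserved for padding
OP_TO_ID = {op: i + 1 for i, op in enumerate(NB101_OPS)}
MAX_NODES = 7  # NAS-Bench-101 cells have at most 7 nodes

def encode_cell(adjacency, ops):
    """Pad a cell to MAX_NODES and map its operator labels to integer tokens.

    adjacency: n x n 0/1 matrix, strictly upper-triangular (edges go from lower to higher index).
    ops: n operator labels, "input" first and "output" last.
    Returns (padded_adjacency, op_ids), both sized for MAX_NODES nodes (padding appended at the end).
    """
    n = len(ops)
    if n > MAX_NODES or len(adjacency) != n or any(len(row) != n for row in adjacency):
        raise ValueError("expected at most 7 nodes and an n x n adjacency matrix")
    if any(adjacency[i][j] for i in range(n) for j in range(i + 1)):
        raise ValueError("adjacency must be strictly upper-triangular")
    pad = MAX_NODES - n
    op_ids = [OP_TO_ID[op] for op in ops] + [PAD_ID] * pad
    padded = [list(row) + [0] * pad for row in adjacency] + [[0] * MAX_NODES for _ in range(pad)]
    return padded, op_ids
```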

Generate architecture pairs

python preprocessing/data_generate.py --dataset nasbench101 --flag extract_seq
python preprocessing/data_generate.py --dataset nasbench101 --flag build_pair --k 2 --d 2000000 --metric params

The corresponding training data and pairs will be saved in ./data/nasbench101/.
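The build_pair step pairs computationally similar architectures. Judging only from the flag and file names (k2, d2000000, metric_params), a plausible reading is that each architecture is paired with up to --k others whose --metric value differs by at most --d. The sketch below implements that reading. It is an assumption, and data_generate.py may sample pairs differently.

```python
import bisect
import random

def build_pairs(metric, k=2, d=2_000_000, seed=0):
    """Pair each architecture with up to k others whose metric differs by at most d.

    metric: per-architecture values, e.g. parameter counts or FLOPs (hypothetical input format).
    Returns a list of (i, j) index pairs with i != j.
    """
    rng = random.Random(seed)
    order = sorted(range(len(metric)), key=metric.__getitem__)
    sorted_vals = [metric[i] for i in order]
    pairs = []
    for i, v in enumerate(metric):
        # Binary search gives the window of architectures with values in [v - d, v + d].
        lo = bisect.bisect_left(sorted_vals, v - d)
        hi = bisect.bisect_right(sorted_vals, v + d)
        candidates = [order[p] for p in range(lo, hi) if order[p] != i]
        for j in rng.sample(candidates, min(k, len(candidates))):
            pairs.append((i, j))
    return pairs
```

Sorting once and using binary search keeps each window lookup logarithmic. Materializing the candidate list is still linear in the window size, which matters when many architectures share similar parameter counts.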

Alternatively, you can download the data train_data.pt, test_data.pt and pair indices train_pair_k2_d2000000_metric_params.pt, test_pair_k2_d2000000_metric_params.pt from here.

Pretraining

bash run_scripts/pretrain_nasbench101.sh

The pretrained models will be saved in ./model/.

Alternatively, you can download the pretrained model nasbench101_model_best.pth from here.

Extract the pretrained encodings

python inference/inference.py --pretrained_path model/nasbench101_model_best.pth.tar --train_data data/nasbench101/train_data.pt --valid_data data/nasbench101/test_data.pt --dataset nasbench101

The extracted embeddings will be saved in ./cate_nasbench101.pt.

Alternatively, you can download the pretrained embeddings cate_nasbench101.pt from here.

Run search experiments on NAS-Bench-101

bash run_scripts/run_search_nasbench101.sh

Search results will be saved in ./nasbench101/.

Experiments on NAS-Bench-301

Dataset preparation

Install nasbench301 and download the xgb_v1.0 and lgb_runtime_v1.0 surrogate model files. You may need to install a version of pytorch_geometric that matches your PyTorch and CUDA versions.

python preprocessing/gen_json_darts.py # randomly sample 1,000,000 archs

Data will be saved in ./data/nasbench301_proxy.json.

Alternatively, you can download the json file nasbench301_proxy.json from here.
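NAS-Bench-301 covers the DARTS search space. There, each cell has four intermediate nodes. Each node picks two inputs from earlier nodes and one of seven operations per input edge. A minimal uniform sampler in the style of DARTS genotypes is sketched below. It is an illustration, not the sampling code in gen_json_darts.py, and it does not deduplicate architectures.

```python
import random

# The seven non-"none" operations of the DARTS search space.
DARTS_OPS = [
    "max_pool_3x3", "avg_pool_3x3", "skip_connect",
    "sep_conv_3x3", "sep_conv_5x5", "dil_conv_3x3", "dil_conv_5x5",
]

def sample_cell(rng, n_nodes=4):
    """Sample one DARTS-style cell as a list of (operation, input_index) edges.

    Intermediate node k (0-based) may read from the two cell inputs (indices 0, 1)
    or any earlier intermediate node (indices 2 .. k+1); it picks two distinct inputs.
    """
    cell = []
    for k in range(n_nodes):
        inputs = rng.sample(range(k + 2), 2)
        for src in sorted(inputs):
            cell.append((rng.choice(DARTS_OPS), src))
    return cell

def sample_architecture(seed=None):
    """Sample independent normal and reduction cells."""
    rng = random.Random(seed)
    return {"normal": sample_cell(rng), "reduce": sample_cell(rng)}
```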

Generate architecture pairs

python preprocessing/data_generate.py --dataset nasbench301 --flag extract_seq
python preprocessing/data_generate.py --dataset nasbench301 --flag build_pair --k 1 --d 5000000 --metric flops

The corresponding training data and pairs will be saved in ./data/nasbench301/.

Alternatively, you can download the data train_data.pt, test_data.pt and pair indices train_pair_k1_d5000000_metric_flops.pt, test_pair_k1_d5000000_metric_flops.pt from here.

Pretraining

bash run_scripts/pretrain_nasbench301.sh

The pretrained models will be saved in ./model/.

Alternatively, you can download the pretrained model nasbench301_model_best.pth from here.

Extract the pretrained encodings

python inference/inference.py --pretrained_path model/nasbench301_model_best.pth.tar --train_data data/nasbench301/train_data.pt --valid_data data/nasbench301/test_data.pt --dataset nasbench301 --n_vocab 11

The extracted encodings will be saved in ./cate_nasbench301.pt.

Alternatively, you can download the pretrained embeddings cate_nasbench301.pt from here.

Run search experiments on NAS-Bench-301

bash run_scripts/run_search_nasbench301.sh

Search results will be saved in ./nasbench301/.

DARTS experiments without surrogate models

Download the pretrained embeddings cate_darts.pt from here.

python search_methods/dngo_ls_darts.py --dim 64 --init_size 16 --topk 5 --dataset darts --output_path bo --embedding_path cate_darts.pt

Search log will be saved in ./darts/. Final search result will be saved in ./darts/bo/dim64.

Evaluate the learned cells from the DARTS search space on CIFAR-10

python darts/cnn/train.py --auxiliary --cutout --arch cate_small
python darts/cnn/train.py --auxiliary --cutout --arch cate_large
  • Expected results (CATE-Small): 2.55% avg. test error with 3.5M model params.
  • Expected results (CATE-Large): 2.46% avg. test error with 4.1M model params.
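The --cutout flag enables Cutout augmentation, which blanks one random square patch of each training image. In the original DARTS training script, the patch side defaults to 16 pixels via --cutout_length. The --auxiliary flag adds an auxiliary classifier head during training. The single-channel sketch below mirrors the usual Cutout recipe. It is illustrative and is not the repository's own transform.

```python
import random

def cutout(image, length=16, rng=None):
    """Zero out one length x length square (clipped at the borders) of an H x W image.

    image: list of rows for a single channel; a new image is returned and the input
    is left unchanged. The patch centre is sampled uniformly over all pixels, so the
    patch can hang partly off the image, as in the standard Cutout recipe.
    """
    rng = rng or random.Random()
    h, w = len(image), len(image[0])
    cy, cx = rng.randrange(h), rng.randrange(w)
    y1, y2 = max(cy - length // 2, 0), min(cy + length // 2, h)
    x1, x2 = max(cx - length // 2, 0), min(cx + length // 2, w)
    return [
        [0 if y1 <= y < y2 and x1 <= x < x2 else px for x, px in enumerate(row)]
        for y, row in enumerate(image)
    ]
```

For RGB images, the same patch is applied to every channel, so the centre should be sampled once per image rather than once per channel.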

Transfer learning on ImageNet

python darts/cnn/train_imagenet.py --arch cate_small --seed 1
python darts/cnn/train_imagenet.py --arch cate_large --seed 1
  • Expected results (CATE-Small): 26.05% test error with 5.0M model params and 556M mult-adds.
  • Expected results (CATE-Large): 25.01% test error with 5.8M model params and 642M mult-adds.
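The ImageNet results report parameters and mult-adds. For a convolution, these are usually counted as shown below, ignoring BatchNorm, activations, and pooling. This is a generic sketch, not the counting code behind the numbers above. DARTS-style sep_conv blocks stack two depthwise-pointwise pairs, so at equal channels and resolution their cost is about twice sep_conv_cost.

```python
def conv2d_cost(c_in, c_out, kernel, out_h, out_w, groups=1, bias=False):
    """Return (params, mult_adds) for a 2-D convolution layer.

    Each output element needs kernel * kernel * (c_in / groups) multiply-adds,
    and there are c_out * out_h * out_w output elements.
    """
    if c_in % groups or c_out % groups:
        raise ValueError("channels must be divisible by groups")
    weights = kernel * kernel * (c_in // groups) * c_out
    params = weights + (c_out if bias else 0)
    mult_adds = weights * out_h * out_w  # bias additions are not counted
    return params, mult_adds

def sep_conv_cost(channels, kernel, out_h, out_w):
    """One depthwise (groups=channels) conv followed by a 1x1 pointwise conv."""
    dw = conv2d_cost(channels, channels, kernel, out_h, out_w, groups=channels)
    pw = conv2d_cost(channels, channels, 1, out_h, out_w)
    return dw[0] + pw[0], dw[1] + pw[1]
```

For example, a 3x3 convolution from 3 to 64 channels on a 32x32 output has 1,728 weights and 1,728 x 1,024 mult-adds.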

Visualize the learned cell

python darts/cnn/visualize.py cate_small
python darts/cnn/visualize.py cate_large

Experiments on outside search space

Build outside search space dataset

bash run_scripts/generate_oo.sh

Data will be saved in ./data/nasbench101_oo_train.json and ./data/nasbench101_oo_test.json.

Generate architecture pairs

python preprocessing/data_generate_oo.py --flag extract_seq
python preprocessing/data_generate_oo.py --flag build_pair

The corresponding training data and pair indices will be saved in ./data/nasbench101/.

Pretraining

python run.py --do_train --parallel --train_data data/nasbench101/nasbench101_oo_trainSet_train.pt --train_pair data/nasbench101/oo_train_pairs_k2_params_dist2e6.pt --valid_data data/nasbench101/nasbench101_oo_trainSet_validation.pt --valid_pair data/nasbench101/oo_validation_pairs_k2_params_dist2e6.pt --dataset oo

The pretrained models will be saved in ./model/.

Extract embeddings on outside search space

# Adjacency encoding
python inference/inference_adj.py
# CATE encoding
python inference/inference.py --pretrained_path model/oo_model_best.pth.tar --train_data data/nasbench101/nasbench101_oo_testSet_split1.pt --valid_data data/nasbench101/nasbench101_oo_testSet_split2.pt --dataset oo_nasbench101

The extracted encodings will be saved as ./adj_oo_nasbench101.pt and ./cate_oo_nasbench101.pt.

Alternatively, you can download the data, pair indices, pretrained models, and extracted embeddings from here.

Run MLP predictor experiments on outside search space

for s in {1..500}; do python search_methods/oo_mlp.py --dim 27 --seed $s --init_size 16 --topk 5 --dataset oo_nasbench101 --output_path np_adj --embedding_path adj_oo_nasbench101.pt; done
for s in {1..500}; do python search_methods/oo_mlp.py --dim 64 --seed $s --init_size 16 --topk 5 --dataset oo_nasbench101 --output_path np_cate --embedding_path cate_oo_nasbench101.pt; done

Search results will be saved in ./oo_nasbench101.
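The search scripts are not reproduced here. The --init_size and --topk flags used above suggest a loop of the following shape: evaluate init_size random architectures, fit a predictor on their encodings, then repeatedly query the topk candidates it ranks highest. In this sketch, a simple k-nearest-neighbour predictor stands in for the DNGO and MLP predictors the repository uses. The budget of 150 is an arbitrary placeholder, and all names are hypothetical.

```python
import random

def knn_predict(x, X, y, k=3):
    """Stand-in predictor: mean score of the k evaluated encodings closest to x."""
    def sq_dist(a, b):
        return sum((ai - bi) ** 2 for ai, bi in zip(a, b))
    nearest = sorted(range(len(X)), key=lambda i: sq_dist(x, X[i]))[:k]
    return sum(y[i] for i in nearest) / len(nearest)

def predictor_guided_search(encodings, evaluate, budget=150, init_size=16, topk=5, seed=0):
    """Query init_size random architectures, then repeatedly refit the predictor on all
    evaluated encodings and query the topk unevaluated candidates it ranks highest.

    encodings: list of equal-length vectors, one per candidate architecture.
    evaluate:  callable mapping an index to its (expensive) score; higher is better.
    budget:    placeholder query budget, not a value taken from the repository.
    Returns (best_index, best_score, number_of_queries).
    """
    rng = random.Random(seed)
    pool = list(range(len(encodings)))
    scores = {i: evaluate(i) for i in rng.sample(pool, min(init_size, len(pool)))}
    while len(scores) < min(budget, len(pool)):
        X = [encodings[i] for i in scores]
        y = [scores[i] for i in scores]
        remaining = [i for i in pool if i not in scores]
        ranked = sorted(remaining, key=lambda i: knn_predict(encodings[i], X, y), reverse=True)
        for i in ranked[:min(topk, budget - len(scores))]:
            scores[i] = evaluate(i)
    best = max(scores, key=scores.get)
    return best, scores[best], len(scores)
```

Running the loop once per seed, as the bash loops above do with --seed, averages out the randomness of the initial sample.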

Citation

If you find this useful for your work, please consider citing:

@InProceedings{yan2021cate,
  title = {CATE: Computation-aware Neural Architecture Encoding with Transformers},
  author = {Yan, Shen and Song, Kaiqiang and Liu, Fei and Zhang, Mi},
  booktitle = {ICML},
  year = {2021}
}