DrNAS: Dirichlet Neural Architecture Search

Related tags

Deep LearningDrNAS
Overview

DrNAS

About

Code accompanying the paper
ICLR'2021: DrNAS: Dirichlet Neural Architecture Search paper
Xiangning Chen*, Ruochen Wang*, Minhao Cheng*, Xiaocheng Tang, Cho-Jui Hsieh

This code is based on the implementation of NAS-Bench-201 and PC-DARTS.

This paper proposes a novel differentiable architecture search method by formulating it into a distribution learning problem. We treat the continuously relaxed architecture mixing weight as random variables, modeled by Dirichlet distribution. With recently developed pathwise derivatives, the Dirichlet parameters can be easily optimized with gradient-based optimizer in an end-to-end manner. This formulation improves the generalization ability and induces stochasticity that naturally encourages exploration in the search space. Furthermore, to alleviate the large memory consumption of differentiable NAS, we propose a simple yet effective progressive learning scheme that enables searching directly on large-scale tasks, eliminating the gap between search and evaluation phases. Extensive experiments demonstrate the effectiveness of our method. Specifically, we obtain a test error of 2.46% for CIFAR-10, 23.7% for ImageNet under the mobile setting. On NAS-Bench-201, we also achieve state-of-the-art results on all three datasets and provide insights for the effective design of neural architecture search algorithms.

Results

On NAS-Bench-201

The table below shows the test accuracy on NAS-Bench-201 space. We achieve the state-of-the-art results on all three datasets. On CIFAR-100, DrNAS even achieves the global optimal with no variance!

Method CIFAR-10 (test) CIFAR-100 (test) ImageNet-16-120 (test)
ENAS 54.30 ± 0.00 10.62 ± 0.27 16.32 ± 0.00
DARTS 54.30 ± 0.00 38.97 ± 0.00 18.41 ± 0.00
SNAS 92.77 ± 0.83 69.34 ± 1.98 43.16 ± 2.64
PC-DARTS 93.41 ± 0.30 67.48 ± 0.89 41.31 ± 0.22
DrNAS (ours) 94.36 ± 0.00 73.51 ± 0.00 46.34 ± 0.00
optimal 94.37 73.51 47.31

For every search process, we sample 100 architectures from the current Dirichlet distribution and plot their accuracy range along with the current architecture selected by Dirichlet mean (solid line). The figure below shows that the accuracy range of the sampled architectures starts very wide but narrows gradually during the search phase. It indicates that DrNAS learns to encourage exploration at the early stages and then gradually reduces it towards the end as the algorithm becomes more and more confident of the current choice. Moreover, the performance of our architectures can consistently match the best performance of the sampled architectures, indicating the effectiveness of DrNAS.

On DARTS Space (CIFAR-10)

DrNAS achieves an average test error of 2.46%, ranking top amongst recent NAS results.

Method Test Error (%) Params (M) Search Cost (GPU days)
ENAS 2.89 4.6 0.5
DARTS 2.76 ± 0.09 3.3 1.0
SNAS 2.85 ± 0.02 2.8 1.5
PC-DARTS 2.57 ± 0.07 3.6 0.1
DrNAS (ours) 2.46 ± 0.03 4.1 0.6

On DARTS Space (ImageNet)

DrNAS can perform a direct search on ImageNet and achieves a top-1 test error below 24.0%!

Method Top-1 Error (%) Params (M) Search Cost (GPU days)
DARTS* 26.7 4.7 1.0
SNAS* 27.3 4.3 1.5
PC-DARTS 24.2 5.3 3.8
DSNAS 25.7 - -
DrNAS (ours) 23.7 5.7 4.6

* not a direct search

Usage

Architecture Search

Search on NAS-Bench-201 Space: (3 datasets to choose from)

  • Data preparation: Please first download the 201 benchmark file and prepare the api follow this repository.

  • cd 201-space && python train_search.py

  • With Progressively Pruning: cd 201-space && python train_search_progressive.py

Search on DARTS Space:

  • Data preparation: For a direct search on ImageNet, we follow PC-DARTS to sample 10% and 2.5% images for earch class as train and validation.

  • CIFAR-10: cd DARTS-space && python train_search.py

  • ImageNet: cd DARTS-space && python train_search_imagenet.py

Architecture Evaluation

  • CIFAR-10: cd DARTS-space && python train.py --cutout --auxiliary

  • ImageNet: cd DARTS-space && python train_imagenet.py --auxiliary

Reference

If you find this code useful in your research please cite

@inproceedings{chen2021drnas,
    title={Dr{\{}NAS{\}}: Dirichlet Neural Architecture Search},
    author={Xiangning Chen and Ruochen Wang and Minhao Cheng and Xiaocheng Tang and Cho-Jui Hsieh},
    booktitle={International Conference on Learning Representations},
    year={2021},
    url={https://openreview.net/forum?id=9FWas6YbmB3}
}

Related Publications

Owner
Xiangning Chen
UCLA CS Ph.D. Student
Xiangning Chen
Generic Foreground Segmentation in Images

Pixel Objectness The following repository contains pretrained model for pixel objectness. Please visit our project page for the paper and visual resul

Suyog Jain 157 Nov 21, 2022
Multi Task RL Baselines

MTRL Multi Task RL Algorithms Contents Introduction Setup Usage Documentation Contributing to MTRL Community Acknowledgements Introduction M

Facebook Research 171 Jan 09, 2023
SigOpt wrappers for scikit-learn methods

SigOpt + scikit-learn Interfacing This package implements useful interfaces and wrappers for using SigOpt and scikit-learn together Getting Started In

SigOpt 73 Sep 30, 2022
Motion planning algorithms commonly used on autonomous vehicles. (path planning + path tracking)

Overview This repository implemented some common motion planners used on autonomous vehicles, including Hybrid A* Planner Frenet Optimal Trajectory Hi

Huiming Zhou 1k Jan 09, 2023
Not Suitable for Work (NSFW) classification using deep neural network Caffe models.

Open nsfw model This repo contains code for running Not Suitable for Work (NSFW) classification deep neural network Caffe models. Please refer our blo

Yahoo 5.6k Jan 05, 2023
BboxToolkit is a tiny library of special bounding boxes.

BboxToolkit is a light codebase collecting some practical functions for the special-shape detection, such as oriented detection

jbwang1997 73 Jan 01, 2023
Code for Understanding Pooling in Graph Neural Networks

Select, Reduce, Connect This repository contains the code used for the experiments of: "Understanding Pooling in Graph Neural Networks" Setup Install

Daniele Grattarola 37 Dec 13, 2022
ILVR: Conditioning Method for Denoising Diffusion Probabilistic Models (ICCV 2021 Oral)

ILVR + ADM This is the implementation of ILVR: Conditioning Method for Denoising Diffusion Probabilistic Models (ICCV 2021 Oral). This repository is h

Jooyoung Choi 225 Dec 28, 2022
This repo contains implementation of different architectures for emotion recognition in conversations.

Emotion Recognition in Conversations Updates 🔥 🔥 🔥 Date Announcements 03/08/2021 🎆 🎆 We have released a new dataset M2H2: A Multimodal Multiparty

Deep Cognition and Language Research (DeCLaRe) Lab 1k Dec 30, 2022
AI that generate music

PianoGPT ai that generate music try it here https://share.streamlit.io/annasajkh/pianogpt/main/main.py or here https://huggingface.co/spaces/Annas/Pia

Annas 28 Nov 27, 2022
IEGAN — Official PyTorch Implementation Independent Encoder for Deep Hierarchical Unsupervised Image-to-Image Translation

IEGAN — Official PyTorch Implementation Independent Encoder for Deep Hierarchical Unsupervised Image-to-Image Translation Independent Encoder for Deep

30 Nov 05, 2022
Code for the Paper: Conditional Variational Capsule Network for Open Set Recognition

Conditional Variational Capsule Network for Open Set Recognition This repository hosts the official code related to "Conditional Variational Capsule N

Guglielmo Camporese 35 Nov 21, 2022
Research on controller area network Intrusion Detection Systems

Group members information Member 1: Lixue Liang Member 2: Yuet Lee Chan Member 3: Xinruo Zhang Member 4: Yifei Han User Manual Generate Attack Packets

Roche 4 Aug 30, 2022
Tightness-aware Evaluation Protocol for Scene Text Detection

TIoU-metric Release on 27/03/2019. This repository is built on the ICDAR 2015 evaluation code. If you propose a better metric and require further eval

Yuliang Liu 206 Nov 18, 2022
[NeurIPS-2021] Slow Learning and Fast Inference: Efficient Graph Similarity Computation via Knowledge Distillation

Efficient Graph Similarity Computation - (EGSC) This repo contains the source code and dataset for our paper: Slow Learning and Fast Inference: Effici

24 Dec 31, 2022
Blind visual quality assessment on 360° Video based on progressive learning

Blind visual quality assessment on omnidirectional or 360 video (ProVQA) Blind VQA for 360° Video via Progressively Learning from Pixels, Frames and V

5 Jan 06, 2023
Kindle is an easy model build package for PyTorch.

Kindle is an easy model build package for PyTorch. Building a deep learning model became so simple that almost all model can be made by copy and paste from other existing model codes. So why code? wh

Jongkuk Lim 77 Nov 11, 2022
Iris prediction model is used to classify iris species created julia's DecisionTree, DataFrames, JLD2, PlotlyJS and Statistics packages.

Iris Species Predictor Iris prediction is used to classify iris species using their sepal length, sepal width, petal length and petal width created us

Siva Prakash 2 Jan 06, 2022
A Comparative Framework for Multimodal Recommender Systems

Cornac Cornac is a comparative framework for multimodal recommender systems. It focuses on making it convenient to work with models leveraging auxilia

Preferred.AI 671 Jan 03, 2023
Use unsupervised and supervised learning to predict stocks

AIAlpha: Multilayer neural network architecture for stock return prediction This project is meant to be an advanced implementation of stacked neural n

Vivek Palaniappan 1.5k Jan 06, 2023