DrNAS: Dirichlet Neural Architecture Search

About

Code accompanying the ICLR 2021 paper DrNAS: Dirichlet Neural Architecture Search.
Xiangning Chen*, Ruochen Wang*, Minhao Cheng*, Xiaocheng Tang, Cho-Jui Hsieh

This code is based on the implementation of NAS-Bench-201 and PC-DARTS.

This paper proposes a novel differentiable architecture search method by formulating it as a distribution learning problem. We treat the continuously relaxed architecture mixing weights as random variables, modeled by a Dirichlet distribution. With recently developed pathwise derivatives, the Dirichlet parameters can be optimized with a gradient-based optimizer in an end-to-end manner. This formulation improves generalization and induces stochasticity that naturally encourages exploration in the search space. Furthermore, to alleviate the large memory consumption of differentiable NAS, we propose a simple yet effective progressive learning scheme that enables searching directly on large-scale tasks, eliminating the gap between the search and evaluation phases. Extensive experiments demonstrate the effectiveness of our method. Specifically, we obtain a test error of 2.46% on CIFAR-10 and 23.7% on ImageNet under the mobile setting. On NAS-Bench-201, we also achieve state-of-the-art results on all three datasets and provide insights for the effective design of neural architecture search algorithms.
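
The core relaxation is straightforward to sketch in PyTorch. The snippet below is a minimal illustration, not the official implementation (the `DirichletMixedOp` module and its log-parameterization are our own simplification): one edge's mixing weights are drawn as a pathwise-differentiable Dirichlet sample, so the concentration parameters receive gradients end-to-end.

```python
import torch
import torch.nn as nn

class DirichletMixedOp(nn.Module):
    """One edge of the supernet: a Dirichlet-weighted mixture of candidate ops."""
    def __init__(self, ops):
        super().__init__()
        self.ops = nn.ModuleList(ops)
        # log-parameterize the concentration so it stays positive
        self.log_alpha = nn.Parameter(torch.zeros(len(ops)))

    def forward(self, x):
        alpha = self.log_alpha.exp()
        # rsample() gives a reparameterized (pathwise-differentiable) draw,
        # so gradients flow back into log_alpha end-to-end
        weights = torch.distributions.Dirichlet(alpha).rsample()
        return sum(w * op(x) for w, op in zip(weights, self.ops))

# toy usage: two candidate operations on a single edge
edge = DirichletMixedOp([nn.Conv2d(3, 3, 3, padding=1), nn.Identity()])
out = edge(torch.randn(1, 3, 8, 8))
out.mean().backward()        # log_alpha receives a gradient
print(edge.log_alpha.grad)
```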

Results

On NAS-Bench-201

The table below shows test accuracy on the NAS-Bench-201 space. We achieve state-of-the-art results on all three datasets. On CIFAR-100, DrNAS even reaches the global optimum with zero variance!

| Method | CIFAR-10 (test) | CIFAR-100 (test) | ImageNet-16-120 (test) |
| --- | --- | --- | --- |
| ENAS | 54.30 ± 0.00 | 10.62 ± 0.27 | 16.32 ± 0.00 |
| DARTS | 54.30 ± 0.00 | 38.97 ± 0.00 | 18.41 ± 0.00 |
| SNAS | 92.77 ± 0.83 | 69.34 ± 1.98 | 43.16 ± 2.64 |
| PC-DARTS | 93.41 ± 0.30 | 67.48 ± 0.89 | 41.31 ± 0.22 |
| DrNAS (ours) | 94.36 ± 0.00 | 73.51 ± 0.00 | 46.34 ± 0.00 |
| optimal | 94.37 | 73.51 | 47.31 |

For every search run, we sample 100 architectures from the current Dirichlet distribution and plot their accuracy range along with the architecture selected by the Dirichlet mean (solid line). The figure below shows that the accuracy range of the sampled architectures starts wide but narrows gradually during the search phase. This indicates that DrNAS encourages exploration at the early stages and then gradually reduces it as the algorithm grows more confident in its current choice. Moreover, the performance of our selected architecture consistently matches the best performance of the sampled architectures, indicating the effectiveness of DrNAS.
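
A minimal sketch of how such samples can be drawn, assuming per-edge concentration vectors (the helper names here are hypothetical): each sampled weight vector is discretized by argmax, while the selected architecture uses the Dirichlet mean alpha / sum(alpha).

```python
import torch

def sample_architecture(alphas):
    """Draw one discrete architecture: argmax of a Dirichlet sample per edge."""
    return [torch.distributions.Dirichlet(a).sample().argmax().item()
            for a in alphas]

def mean_architecture(alphas):
    """Select by the Dirichlet mean alpha / alpha.sum() for each edge."""
    return [(a / a.sum()).argmax().item() for a in alphas]

# toy setting: 6 edges, 5 candidate operations each
alphas = [torch.rand(5) + 0.5 for _ in range(6)]
samples = [sample_architecture(alphas) for _ in range(100)]
selected = mean_architecture(alphas)
# each architecture (a list of operation indices) could then be converted
# to an arch string and looked up in the NAS-Bench-201 API for its accuracy
```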

On DARTS Space (CIFAR-10)

DrNAS achieves an average test error of 2.46%, ranking among the best of recent NAS results.

| Method | Test Error (%) | Params (M) | Search Cost (GPU days) |
| --- | --- | --- | --- |
| ENAS | 2.89 | 4.6 | 0.5 |
| DARTS | 2.76 ± 0.09 | 3.3 | 1.0 |
| SNAS | 2.85 ± 0.02 | 2.8 | 1.5 |
| PC-DARTS | 2.57 ± 0.07 | 3.6 | 0.1 |
| DrNAS (ours) | 2.46 ± 0.03 | 4.1 | 0.6 |

On DARTS Space (ImageNet)

DrNAS can perform a direct search on ImageNet and achieves a top-1 test error below 24.0%!

| Method | Top-1 Error (%) | Params (M) | Search Cost (GPU days) |
| --- | --- | --- | --- |
| DARTS* | 26.7 | 4.7 | 1.0 |
| SNAS* | 27.3 | 4.3 | 1.5 |
| PC-DARTS | 24.2 | 5.3 | 3.8 |
| DSNAS | 25.7 | - | - |
| DrNAS (ours) | 23.7 | 5.7 | 4.6 |

\* not a direct search

Usage

Architecture Search

Search on the NAS-Bench-201 space (three datasets to choose from):

  • Data preparation: Please first download the NAS-Bench-201 benchmark file and prepare the API by following this repository (a minimal loading sketch follows this list).

  • `cd 201-space && python train_search.py`

  • With progressive pruning: `cd 201-space && python train_search_progressive.py`
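
As a reference, here is a minimal sketch of loading the benchmark API; the file name below is the one distributed with the NAS-Bench-201 v1.1 release and may differ for other releases, and the example architecture string just follows that repository's encoding format.

```python
from nas_201_api import NASBench201API as API

# adjust the path/file name to the benchmark release you downloaded
api = API('NAS-Bench-201-v1_1-096bd3.pth')
print(len(api))  # number of architectures in the space (15625)

# look up one architecture by its string encoding and fetch its records
index = api.query_index_by_arch(
    '|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+'
    '|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|')
results = api.query_by_index(index)
```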

Search on the DARTS space:

  • Data preparation: For a direct search on ImageNet, we follow PC-DARTS and sample 10% and 2.5% of the images from each class as the training and validation sets, respectively (a subsampling sketch follows this list).

  • CIFAR-10: `cd DARTS-space && python train_search.py`

  • ImageNet: `cd DARTS-space && python train_search_imagenet.py`
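
A minimal sketch of that subsampling step, assuming the standard ImageNet directory layout with one folder per class (all paths here are hypothetical):

```python
import os
import random
import shutil

# hypothetical paths; adjust to your local ImageNet copy
SRC = 'imagenet/train'
DST_TRAIN = 'imagenet_search/train'
DST_VAL = 'imagenet_search/val'

random.seed(0)
for cls in os.listdir(SRC):
    images = sorted(os.listdir(os.path.join(SRC, cls)))
    random.shuffle(images)
    n_train = int(0.10 * len(images))   # 10% per class for search training
    n_val = int(0.025 * len(images))    # 2.5% per class for search validation
    for dst, names in ((DST_TRAIN, images[:n_train]),
                       (DST_VAL, images[n_train:n_train + n_val])):
        os.makedirs(os.path.join(dst, cls), exist_ok=True)
        for name in names:
            shutil.copy(os.path.join(SRC, cls, name),
                        os.path.join(dst, cls, name))
```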

Architecture Evaluation

  • CIFAR-10: `cd DARTS-space && python train.py --cutout --auxiliary`

  • ImageNet: `cd DARTS-space && python train_imagenet.py --auxiliary`

Reference

If you find this code useful in your research, please cite:

@inproceedings{chen2021drnas,
    title={Dr{NAS}: Dirichlet Neural Architecture Search},
    author={Xiangning Chen and Ruochen Wang and Minhao Cheng and Xiaocheng Tang and Cho-Jui Hsieh},
    booktitle={International Conference on Learning Representations},
    year={2021},
    url={https://openreview.net/forum?id=9FWas6YbmB3}
}
