Official implementation for "Symbolic Learning to Optimize: Towards Interpretability and Scalability"

Overview

Symbolic Learning to Optimize

This is the official implementation of the ICLR 2022 paper "Symbolic Learning to Optimize: Towards Interpretability and Scalability".

Introduction

Recent studies on Learning to Optimize (L2O) suggest a promising path to automating and accelerating the optimization procedure for complicated tasks. Existing L2O models parameterize optimization rules by neural networks and learn those numerical rules via meta-training. However, they face two common pitfalls: (1) scalability: the numerical rules represented by neural networks create extra memory overhead for applying L2O models and limit their applicability to optimizing larger tasks; (2) interpretability: it is unclear what each L2O model has learned in its black-box optimization rule, nor is it straightforward to compare different L2O models in an explainable way. To avoid both pitfalls, this paper proves the concept that we can "kill two birds with one stone" by introducing the powerful tool of symbolic regression to L2O. In this paper, we establish a holistic symbolic representation and analysis framework for L2O, which yields a series of insights for learnable optimizers. Leveraging our findings, we further propose a lightweight L2O model that can be meta-trained on large-scale problems and outperform human-designed and tuned optimizers. Our work is set to supply a brand-new perspective to L2O research.

Our approach:

First, train a neural-network-based (LSTM) optimizer; then use the symbolic regression tool to troubleshoot and analyze that optimizer. The resulting symbolic rule serves as a lightweight surrogate of the original optimizer.
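As a toy illustration of the distillation idea only, the sketch below treats a hypothetical black-box update rule as the "trained optimizer", records its updates on synthetic inputs, and recovers a closed-form surrogate by ordinary least squares over two hand-picked candidate terms (the gradient g and its moving average m). This is a swapped-in, much simpler technique than the symbolic regression the repository uses (see torch-implementation/sr_train.py and srUtils.py): real symbolic regression searches over expression structures instead of fixing the terms in advance. The function names, coefficients, and candidate terms are all assumptions made for this example.

```python
import random


def black_box_update(g, m):
    # Hypothetical stand-in for a trained neural optimizer's update rule.
    return -0.01 * g - 0.002 * m


def fit_two_terms(x1, x2, y):
    """Least-squares fit of y ~ a*x1 + b*x2 via the 2x2 normal equations."""
    s11 = sum(u * u for u in x1)
    s22 = sum(v * v for v in x2)
    s12 = sum(u * v for u, v in zip(x1, x2))
    s1y = sum(u * t for u, t in zip(x1, y))
    s2y = sum(v * t for v, t in zip(x2, y))
    det = s11 * s22 - s12 * s12
    if det == 0:
        raise ValueError("candidate terms are collinear")
    # Cramer's rule on [[s11, s12], [s12, s22]] @ [a, b] = [s1y, s2y]
    a = (s1y * s22 - s12 * s2y) / det
    b = (s11 * s2y - s12 * s1y) / det
    return a, b


def distill(steps=500, beta=0.9, seed=0):
    """Record (g, m, update) triples from the black box and fit a surrogate."""
    rng = random.Random(seed)
    gs, ms, updates = [], [], []
    m = 0.0
    for _ in range(steps):
        g = rng.gauss(0.0, 1.0)        # synthetic gradient
        m = beta * m + (1 - beta) * g  # exponential moving average of gradients
        gs.append(g)
        ms.append(m)
        updates.append(black_box_update(g, m))
    return fit_two_terms(gs, ms, updates)


if __name__ == "__main__":
    a, b = distill()
    print(f"surrogate: update = {a:.4f} * g + {b:.4f} * m")
```

Because the synthetic data is noiseless and the black box happens to lie in the span of the chosen terms, the fit recovers the hidden coefficients; with a real neural optimizer neither condition holds, which is why the quality of the fit has to be measured.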

Our main findings:

Example of distilled equations from the DM model:

Example of distilled equations from RP model (they are simpler than the DM surrogates, and yet more effective for the optimization task):

Distilled symbolic rules fit the optimizer quite well:

The distilled symbolic rule and underlying rules

Distilled symbolic rules perform the same optimization tasks well, compared with the original numerical optimizer:

The lightweight symbolic rules can be meta-tuned on large-scale optimizees (ResNet-50) and achieve good performance:


The symbolic regression passed the sanity checks in the optimization tasks:

Installation Guide

The installation requires no special packages. The TensorFlow version we adopted is 1.14.0, and the PyTorch version we adopted is 1.7.1.
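A small sketch for checking an environment against these versions before running anything. It simply imports each package and reads its `__version__` attribute (dropping a local build tag such as `+cu110`); the helper names are hypothetical. Since the TensorFlow and PyTorch implementations are independent, a missing framework is reported rather than raised.

```python
import importlib

# Versions listed in the Installation Guide.
EXPECTED = {"tensorflow": "1.14.0", "torch": "1.7.1"}


def installed_version(module_name):
    """Return the module's __version__ (local build tag stripped), or None."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    version = getattr(module, "__version__", None)
    # e.g. a CUDA build of torch may report "1.7.1+cu110"; keep the public part.
    return version.split("+")[0] if version else None


def check_versions(expected=EXPECTED):
    """Map each mismatched package to (expected, installed-or-None)."""
    report = {}
    for name, want in expected.items():
        have = installed_version(name)
        if have != want:
            report[name] = (want, have)
    return report


if __name__ == "__main__":
    for name, (want, have) in check_versions().items():
        print(f"{name}: expected {want}, found {have or 'not installed'}")
```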

Training Guide

The three files:

torch-implementation/l2o_train_from_scratch.py

torch-implementation/l2o_symbolic_regression_stage_2_3.py

torch-implementation/l2o_evaluation.py

are pipeline scripts that integrate the multi-stage experiments. Their detailed usage is specified within these files. We offer several examples below.

  • In order to train an RNN-prop model from scratch on the MNIST classification problem with 200 epochs (each of length 200), unroll length 20, batch size 128, and learning rate 0.001 on GPU 0, run:

    python l2o_train_from_scratch.py -m tras -p mni -n 200 -l 200 -r 20 -b 128 -lr 0.001 -d 0

  • In order to fine-tune an L2O model on the CNN optimizee with 200 epochs (each of length 1000), unroll length 20, batch size 64, and learning rate 0.001 on GPU 0, first place the .pth model checkpoint file (the training script above automatically saves it in a new folder under the current directory) at the first location (index 0 in Python) of __WELL_TRAINED__ in torch-implementation/utils.py, then run:

    python l2o_train_from_scratch.py -m tune -pr 0 -p cnn -n 200 -l 1000 -r 20 -b 64 -lr 0.001 -d 0

  • In order to generate data for symbolic regression, for example 50,000 samples evaluated on the MNIST classification problem with an optimization trajectory length of 300 steps on GPU 3, run:

    python l2o_evaluation.py -m srgen -p mni -l 300 -s 50000 -d 3

  • In order to distill equations from the previously saved offline SR dataset, check and run: torch-implementation/sr_train.py

  • In order to fine-tune the SR equations, check and run: torch-implementation/stage023_mid2021_update.py

  • In order to convert a distilled symbolic equation into readable LaTeX, check and run: torch-implementation/sr_test_get_latex.py

  • In order to measure how well the symbolic rule fits the original model, we use the R² score; to compute it, check and run: torch-implementation/sr_test_cal_r2.py

  • In order to train and run the ResNet-class optimizees, check and run: torch-implementation/run_resnet.py
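The example commands above share one pattern: `python <script>` followed by short flag/value pairs. If you want to script sweeps over these settings, a minimal helper like the following can rebuild them. It is hypothetical (not part of the repository) and only reproduces flags already shown in the examples; keyword names are the flag names without the leading dash.

```python
import shlex


def build_command(script, **flags):
    """Build an argv list such as ['python', script, '-m', 'tras', ...].

    Keyword names are flag names without the leading dash (lr -> -lr).
    Keyword order is preserved, so flags appear in the order given.
    """
    argv = ["python", script]
    for name, value in flags.items():
        argv += ["-" + name, str(value)]
    return argv


def as_shell(argv):
    """Render an argv list as a copy-pasteable shell string."""
    return " ".join(shlex.quote(arg) for arg in argv)


if __name__ == "__main__":
    # Learning-rate sweep around the MNIST training example from this guide.
    for lr in (0.001, 0.0005):
        cmd = build_command("l2o_train_from_scratch.py",
                            m="tras", p="mni", n=200, l=200,
                            r=20, b=128, lr=lr, d=0)
        print(as_shell(cmd))
        # To launch from torch-implementation/: subprocess.run(cmd, check=True)
        # (requires `import subprocess`).
```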

There are also optional TensorFlow implementations of L2O, including meta-training of the two benchmarks used in this paper: DM and RNN-prop L2O. However, all pipeline steps before generating the offline datasets are supported only by the torch implementation. To perform symbolic regression with the TensorFlow implementation, you need to manually generate records (an .npy file) of shape [N_sample, num_feature+1], which concatenate the num_feature-dimensional x (the symbolic regression input) and the 1-dimensional y (the output) for N_sample samples. Once the behavior dataset is ready, the remaining steps are shared with the torch implementation.
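A minimal sketch of assembling such a record with NumPy, assuming x occupies the first num_feature columns and y the last one (the order in which the text lists them). Which features go into x depends on the optimizer being distilled and is not specified here, so the random arrays below are placeholders, and `make_sr_record` is a hypothetical helper, not a function from the repository.

```python
import os
import tempfile

import numpy as np


def make_sr_record(x, y, path=None):
    """Concatenate SR inputs x [N_sample, num_feature] and targets y [N_sample]
    into one [N_sample, num_feature + 1] array, optionally saving it as .npy."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float).reshape(-1, 1)
    if x.ndim != 2 or x.shape[0] != y.shape[0]:
        raise ValueError("x must be 2-D with one row per entry of y")
    record = np.concatenate([x, y], axis=1)  # assumed layout: y is the last column
    if path is not None:
        np.save(path, record)
    return record


if __name__ == "__main__":
    rng = np.random.RandomState(0)
    x = rng.randn(1000, 4)  # placeholder features, e.g. gradient statistics
    y = rng.randn(1000)     # placeholder targets, e.g. the optimizer's update
    path = os.path.join(tempfile.mkdtemp(), "sr_record.npy")
    make_sr_record(x, y, path)
    print(np.load(path).shape)  # (1000, 5)
```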

  • In order to train the tensorflow implementation of L2O, check and run: tensorflow-implementation/train_rnnprop.py, tensorflow-implementation/train_dm.py

  • In order to evaluate the tensorflow implementation of L2O and generate offline dataset for symbolic regression, check and run: tensorflow-implementation/evaluate_rnnprop.py, tensorflow-implementation/evaluate_dm.py.

Other hints

Meta-train the DM/RP/RP_si models:

Run the train_optimizer() function in torch-implementation/meta.py

Evaluate the optimization performance:

Run the eva_l2o_optimizer() function in torch-implementation/meta.py

RP model implementations:

RPOptimizer in torch-implementation/meta.py

RP_si model implementations:

Same as RP, but set magic=0; optionally, more diverse inputs can be enabled by setting grad_features="mt+gt+mom5+mom99"

DM model implementations:

DMOptimizer in torch-implementation/utils.py

SR implementations:

torch-implementation/sr_train.py

torch-implementation/sr_test_cal_r2.py

torch-implementation/sr_test_get_latex.py
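sr_test_cal_r2.py reports the R² score of a distilled rule against the original optimizer's outputs. For reference, the standard coefficient of determination, R² = 1 - SS_res / SS_tot, can be sketched in plain Python as follows; this is the textbook definition, not code taken from that script.

```python
def r2_score(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    y_true, y_pred = list(y_true), list(y_pred)
    if len(y_true) == 0 or len(y_true) != len(y_pred):
        raise ValueError("inputs must be non-empty and of equal length")
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))  # residual error
    ss_tot = sum((t - mean) ** 2 for t in y_true)               # total variance
    if ss_tot == 0:
        raise ValueError("R^2 is undefined for constant targets")
    return 1.0 - ss_res / ss_tot


if __name__ == "__main__":
    # Outputs of the original optimizer vs. predictions of a symbolic surrogate.
    print(f"{r2_score([1.0, 2.0, 3.0, 4.0], [1.1, 1.9, 3.2, 3.8]):.3f}")  # 0.980
```

A score of 1 means the surrogate reproduces the original outputs exactly, while 0 means it does no better than always predicting their mean.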

Other SR options and the workflow:

srUtils.py

Citation

Coming soon.

Owner
VITA
Visual Informatics Group @ University of Texas at Austin