Neural network pruning for finding a sparse computational model for controlling a biological motor task.

Overview

MothPruning

Scientific Overview

Originally inspired by biological nervous systems, deep neural networks (DNNs) are powerful computational tools for modeling complex systems. DNNs are used across a wide range of domains and have helped solve some of the most intractable problems in physics, biology, and computer science. Despite their prevalence, using DNNs as a modeling tool has some major downsides. DNNs are highly overparameterized, which often makes them difficult to generalize and interpret, as well as computationally expensive. Unlike DNNs, which are often trained until they reach the highest possible accuracy, biological networks must balance performance with robustness to a noisy and dynamic environment. Biological neural systems use a variety of mechanisms to promote specialized and efficient pathways capable of performing complex tasks in the presence of noise. One such mechanism, synaptic pruning, plays a significant role in refining task-specific behaviors: it produces a more sparsely connected network that can still perform complex cognitive and motor tasks. Here, we draw inspiration from biology and use DNNs and neural network pruning to find a sparse computational model for controlling a biological motor task.

In this work, we use the inertial dynamics model in [2] to simulate examples of M. sexta hovering flight. These data are used to train a DNN to learn the controllers for hovering. Drawing inspiration from pruning in biological neural systems, we sparsify the network using neural network pruning. Here, we prune weights based simply on their magnitudes, removing the weights closest to zero. Insects must maneuver through high-noise environments to accomplish controlled flight. It is often assumed that there is a trade-off between perfect flight control and robustness to noise, and that the sensory data may be limited by the signal-to-noise ratio. Thus, the network need not be trained to the most accurate possible model, since in practice noise prevents high-fidelity models from realizing their underlying accuracy. Rather, we seek the sparsest model capable of performing the task in a noisy environment. We employ two methods for neural network pruning: manually setting weights to zero, or using binary masking layers. Furthermore, the DNN is pruned sequentially: groups of weights are removed gradually from the network, with retraining between successive prunes, until a target sparsity is reached. Monte Carlo simulations are also used to quantify the statistical distribution of network weights during pruning, given random initialization of the network weights.
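Magnitude-based pruning can be illustrated with a small JAX helper that builds a 0/1 mask removing a given fraction of the weights closest to zero. This is a minimal sketch, not the repository's code: the name `magnitude_mask` and the per-matrix (rather than global) ranking are assumptions. Multiplying by the mask is equivalent to "manually" setting those weights to zero; keeping the mask around is what a masking layer does.

```python
import jax.numpy as jnp


def magnitude_mask(weights, sparsity):
    """Hypothetical helper: return a 0/1 mask that zeros the `sparsity`
    fraction of entries with the smallest absolute values."""
    flat = jnp.abs(weights).ravel()
    k = int(round(sparsity * flat.size))    # number of weights to prune
    order = jnp.argsort(flat)               # indices from smallest to largest |w|
    mask = jnp.ones_like(flat).at[order[:k]].set(0.0)
    return mask.reshape(weights.shape)


w = jnp.array([[0.5, -0.1],
               [0.05, -2.0]])
mask = magnitude_mask(w, 0.5)   # prune the 2 of 4 weights closest to zero
pruned = w * mask               # zeroing weights == multiplying by the mask
# mask → [[1., 0.], [0., 1.]]
```

Ranking within each weight matrix keeps the example short; a global ranking across all layers is an equally valid design choice.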

For more information, please see our paper [1].


Project Description

The deep, fully-connected neural network was constructed with ten input variables and seven output variables. The initial and final state-space conditions are the inputs to the network: six initial-state variables (subscript i) and four final-state variables (subscript f). The output layer predicts three control variables (including x and y components) and four final derivatives of the state space (subscript f).
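A fully-connected network with the stated 10 inputs and 7 outputs can be sketched in plain JAX as follows. The hidden widths in `LAYER_SIZES`, the tanh activation, and the initialization scale are illustrative assumptions, not the hyperparameters used in step1_train.py.

```python
import jax
import jax.numpy as jnp

# 10 inputs and 7 outputs follow the project description;
# the two hidden layers of width 64 are a hypothetical choice.
LAYER_SIZES = [10, 64, 64, 7]


def init_params(key, sizes=LAYER_SIZES):
    params = []
    keys = jax.random.split(key, len(sizes) - 1)
    for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:]):
        # Glorot-style scaling keeps initial activations in a reasonable range
        w = jax.random.normal(k, (n_in, n_out)) * jnp.sqrt(2.0 / (n_in + n_out))
        params.append((w, jnp.zeros(n_out)))
    return params


def forward(params, x):
    for w, b in params[:-1]:
        x = jnp.tanh(x @ w + b)          # hidden layers (assumed tanh)
    w, b = params[-1]
    return x @ w + b                     # linear output layer for regression


params = init_params(jax.random.PRNGKey(0))
batch = jnp.ones((5, 10))                # 5 placeholder input vectors
out = forward(params, batch)             # shape (5, 7)
```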

After the fully-connected network is trained to a minimum error, we use neural network pruning to promote sparsity between the network layers. A target sparsity (the percentage of network weights to prune) is specified, and the smallest-magnitude weights are forced to zero. The network is then retrained until it again reaches a minimum error. This process is repeated until most of the weights have been pruned from the network.
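The prune-and-retrain cycle can be sketched on a toy linear model in JAX. This is a simplified illustration under stated assumptions: the sparsity schedule, the fixed number of retraining steps (instead of retraining to a minimum error), the learning rate, the synthetic data, and the helper names `magnitude_mask` and `retrain` are all hypothetical.

```python
import jax
import jax.numpy as jnp


def magnitude_mask(w, sparsity):
    """0/1 mask pruning the `sparsity` fraction of smallest-|w| entries."""
    flat = jnp.abs(w).ravel()
    k = int(round(sparsity * flat.size))
    order = jnp.argsort(flat)
    return jnp.ones_like(flat).at[order[:k]].set(0.0).reshape(w.shape)


def loss_fn(w, x, y):
    return jnp.mean((x @ w - y) ** 2)


grad_fn = jax.jit(jax.grad(loss_fn))


def retrain(w, mask, x, y, steps=200, lr=0.5):
    for _ in range(steps):
        # Gating the gradient with the mask keeps pruned weights exactly zero
        w = w - lr * grad_fn(w, x, y) * mask
    return w


key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (256, 10))     # toy inputs (10 features)
y = x @ jax.random.normal(key_w, (10, 7))   # toy targets (7 outputs)

w = jnp.zeros((10, 7))
mask = jnp.ones_like(w)
w = retrain(w, mask, x, y)                  # dense training first

history = []
for target in [0.25, 0.5, 0.75]:            # hypothetical sparsity schedule
    # Prune the smallest remaining weights up to the new target sparsity
    mask = magnitude_mask(w * mask, target)
    w = retrain(w * mask, mask, x, y)
    history.append((target, float(loss_fn(w, x, y))))
```

Because already-pruned weights are exactly zero, they rank first in each new magnitude ordering, so each prune only adds to the set of removed weights.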

The training and pruning protocols were developed using Keras with the TensorFlow backend. To scale up training for the statistical analysis of many networks, the training and pruning protocols were parallelized using the Jax framework.
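Training many networks in parallel maps naturally onto `jax.vmap`: initializing from a batch of random keys gives every parameter a leading "network" axis, and the same loss and gradient code then runs on all networks at once. This is a minimal sketch, not the code in step1_train.py; `NUM_PARALLEL` merely plays the role of the numParallel parameter described in Step 1, and the layer sizes, learning rate, and data are placeholders.

```python
import jax
import jax.numpy as jnp

NUM_PARALLEL = 8  # hypothetical number of networks trained side by side


def init_net(key, n_in=10, n_hidden=32, n_out=7):
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (n_in, n_hidden)) / jnp.sqrt(n_in),
        "b1": jnp.zeros(n_hidden),
        "w2": jax.random.normal(k2, (n_hidden, n_out)) / jnp.sqrt(n_hidden),
        "b2": jnp.zeros(n_out),
    }


def forward(params, x):
    h = jnp.tanh(x @ params["w1"] + params["b1"])
    return h @ params["w2"] + params["b2"]


def mse(params, x, y):
    return jnp.mean((forward(params, x) - y) ** 2)


# One independent random initialization per network (Monte Carlo style):
# every leaf of `stacked` gains a leading axis of size NUM_PARALLEL.
keys = jax.random.split(jax.random.PRNGKey(0), NUM_PARALLEL)
stacked = jax.vmap(init_net)(keys)

x = jnp.ones((16, 10))   # placeholder batch shared by all networks
y = jnp.zeros((16, 7))

# in_axes=(0, None, None): separate parameters, shared data
losses = jax.vmap(mse, in_axes=(0, None, None))(stacked, x, y)
grads = jax.vmap(jax.grad(mse), in_axes=(0, None, None))(stacked, x, y)
stacked = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, stacked, grads)
```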

To ensure that weights remain pruned during retraining, we used the pruning functionality of the TensorFlow Model Optimization Toolkit, which provides functions for pruning deep neural networks. In this toolkit, pruning is achieved with binary masking layers that are multiplied element-wise with each weight matrix in the network.
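The reason an element-wise mask keeps weights pruned can be shown in a few lines of JAX. This is not the Model Optimization Toolkit's API; it is a framework-neutral illustration with a made-up layer shape, mask, and data. When a layer computes with `w * mask`, the chain rule multiplies the gradient by the mask, so pruned entries receive zero gradient and the effective weight stays zero.

```python
import jax
import jax.numpy as jnp


def masked_dense(w, b, mask, x):
    # Effective weight matrix is the element-wise product of weights and mask
    return x @ (w * mask) + b


def loss(w, b, mask, x, y):
    return jnp.mean((masked_dense(w, b, mask, x) - y) ** 2)


kw, kx = jax.random.split(jax.random.PRNGKey(0))
w = jax.random.normal(kw, (4, 3))
b = jnp.zeros(3)
mask = jnp.array([[1., 0., 1.],
                  [0., 1., 1.],
                  [1., 1., 0.],
                  [0., 0., 1.]])      # hypothetical 0/1 pruning mask
x = jax.random.normal(kx, (8, 4))
y = jnp.zeros((8, 3))

# d(w * mask)/dw = mask, so the gradient is zero wherever mask == 0
g = jax.grad(loss)(w, b, mask, x, y)
```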

Because Jax does not come with a pruning toolkit, pruning with binary masking matrices was coded directly into the Jax training loop.
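Coding the masks into a parallel training loop means each network needs its own mask, computed from its own weight magnitudes. The sketch below does this with `jax.vmap` for a stack of toy linear networks; `num_nets`, the 50% sparsity, `LR`, the data, and the helper names `magnitude_mask` and `masked_step` are hypothetical, and the repository's loop may differ.

```python
import jax
import jax.numpy as jnp

LR = 0.1  # hypothetical learning rate


def magnitude_mask(w, sparsity):
    flat = jnp.abs(w).ravel()
    k = int(round(sparsity * flat.size))   # static: depends only on the shape
    order = jnp.argsort(flat)
    return jnp.ones_like(flat).at[order[:k]].set(0.0).reshape(w.shape)


def loss(w, x, y):
    return jnp.mean((x @ w - y) ** 2)


def masked_step(w, mask, x, y):
    g = jax.grad(loss)(w, x, y)
    # Re-applying the mask after the update keeps pruned weights at zero
    return (w - LR * g) * mask


num_nets = 4
keys = jax.random.split(jax.random.PRNGKey(1), num_nets + 1)
x = jax.random.normal(keys[0], (32, 10))
y = jnp.zeros((32, 7))
weights = jax.vmap(lambda k: jax.random.normal(k, (10, 7)))(keys[1:])

# One mask per network, each ranked by that network's own magnitudes
masks = jax.vmap(lambda w: magnitude_mask(w, 0.5))(weights)
weights = weights * masks

step = jax.vmap(masked_step, in_axes=(0, 0, None, None))
for _ in range(10):
    weights = step(weights, masks, x, y)
```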

Installation

Create a new conda environment with the tools for generating data and training the network (note that this environment requires a GPU and the correct NVIDIA drivers).

conda env create -f environment_ODE_DL.yml

Create a kernelspec so that you can see this kernel in JupyterLab.

conda activate [environment name]
python -m ipykernel install --user --name [environment name]
conda deactivate

To install Jax and Flax, follow the instructions on the Jax GitHub page.

Data

To use the TensorFlow version of this code, you need to generate simulations of moth hovering to use as data. The Jax version (multi-network train and prune) has data provided in this repository.

cd MothMachineLearning/Underactuated/GenerateData

and use 010_OneTorqueParallelSims.ipynb to generate the simulations.

How to use

The following guide walks through the process of training and pruning many networks in parallel using the Jax framework. However, the TensorFlow code is also provided for experimentation and visualization.

Step 1: Train networks

cd MothMachineLearning/Underactuated/TrainNetwork/multiNetPrune/

First we train and prune the desired number of networks in parallel using the Jax framework. Choose the number of networks you wish to train/prune in parallel by adjusting the numParallel parameter. You can also define the number of layers, units, and other hyperparameters. Use the command

python3 step1_train.py

to train and prune the networks in parallel.

Step 2: Evaluate at prunes

Next, the networks need to be evaluated at each prune. Use the command

python3 step2_pruneEval.py

to evaluate the networks at each prune.

Step 3: Pre-process networks

This step prepares the networks for sparse network identification (explained in the next step); it mainly reorganizes the data. Open and run step3_preprocess.ipynb, making sure to change modeltimestamp and the file names to the correct ones for your run.

Step 4: Find sparse networks

This code finds the optimally sparse networks. For each network, the most pruned version whose loss is below a specified threshold (here 0.001) is kept. For example, the image below shows a single network that has gone through the sequential pruning process, with the red line marking the defined threshold. In this example, the optimally sparse network is the one pruned by 94% (i.e., 6% of the original weights remain).

[Figure: loss of a single network across the sequential pruning process; the red line marks the loss threshold.]

The sparse networks are collected and saved to a file called sparseNetworks.pkl. Open and run step4_findSparse.ipynb, making sure to change modeltimestamp and the file names to the correct ones for your run.

Note that if a network has no prune with a loss below the threshold, it is skipped and not included in sparseNetworks. For example, if you trained and pruned 10 networks and 3 had no prune below a loss of 0.001, sparseNetworks will have length 7.
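The selection rule (keep the most pruned version below the loss threshold, skip networks with none) can be sketched in a few lines of Python. The input layout, a dictionary mapping a network id to `(sparsity, loss)` pairs, the name `find_sparse`, and the numbers in `results` are hypothetical toy values; step4_findSparse.ipynb works on the outputs of the earlier steps and writes its result to sparseNetworks.pkl.

```python
THRESHOLD = 1e-3  # loss threshold used in the text


def find_sparse(results, threshold=THRESHOLD):
    """results: {net_id: [(sparsity, loss), ...]} (hypothetical layout)."""
    sparse = []
    for net_id, prunes in results.items():
        below = [(s, l) for s, l in prunes if l < threshold]
        if not below:
            continue  # no prune meets the threshold: skip this network
        s, l = max(below)  # tuples compare by sparsity first: most pruned wins
        sparse.append({"net": net_id, "sparsity": s, "loss": l})
    return sparse


# Made-up losses for three networks; network 1 never drops below 0.001
results = {
    0: [(0.0, 2e-4), (0.5, 3e-4), (0.94, 8e-4), (0.96, 5e-3)],
    1: [(0.0, 4e-3), (0.5, 6e-3)],
    2: [(0.0, 1e-4), (0.8, 9e-4), (0.9, 2e-3)],
}
sparse = find_sparse(results)   # keeps networks 0 and 2 only
```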

References

[1] Zahn, O., Bustamante, Jr J., Switzer, C., Daniel, T., and Kutz, J. N. (2022). Pruning deep neural networks generates a sparse, bio-inspired nonlinear controller for insect flight.

[2] Bustamante, Jr J., Ahmed, M., Deora, T., Fabien, B., and Daniel, T. (2021). Abdominal movements in insect flight reshape the role of non-aerodynamic structures for flight maneuverability. J. Integrative and Comparative Biology. In revision.

Owner
Olivia Thomas
Physics graduate student at the University of Washington