Learning Neural Network Subspaces

Welcome to the codebase for Learning Neural Network Subspaces by Mitchell Wortsman, Maxwell Horton, Carlos Guestrin, Ali Farhadi, Mohammad Rastegari.

Abstract

Recent observations have advanced our understanding of the neural network optimization landscape, revealing the existence of (1) paths of high accuracy containing diverse solutions and (2) wider minima offering improved performance. Previous methods observing diverse paths require multiple training runs. In contrast, we aim to leverage both properties (1) and (2) with a single method and in a single training run. With a similar computational cost as training one model, we learn lines, curves, and simplexes of high-accuracy neural networks. These neural network subspaces contain diverse solutions that can be ensembled, approaching the ensemble performance of independently trained networks without the training cost. Moreover, using the subspace midpoint boosts accuracy, calibration, and robustness to label noise, outperforming Stochastic Weight Averaging.

Code Overview

In this repository we walk through learning neural network subspaces with PyTorch. We ground the discussion in the example of learning a line of neural networks. In our code, a line is defined by the endpoints weight and weight1, and a point on the line is given by w = (1 - alpha) * weight + alpha * weight1 for some alpha in [0, 1].
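As a quick illustration of this parametrization, here is a standalone sketch that is not code from the repository. Random tensors stand in for one layer's weight and weight1. Setting alpha = 0 or alpha = 1 recovers an endpoint, and alpha = 0.5 gives the midpoint.

```python
import torch


def point_on_line(weight, weight1, alpha):
    # w = (1 - alpha) * weight + alpha * weight1, with alpha in [0, 1]
    return (1 - alpha) * weight + alpha * weight1


torch.manual_seed(0)
weight = torch.randn(4, 3, 3, 3)   # stand-in for one endpoint (a conv weight)
weight1 = torch.randn(4, 3, 3, 3)  # stand-in for the other endpoint

start = point_on_line(weight, weight1, 0.0)     # equals weight
end = point_on_line(weight, weight1, 1.0)       # equals weight1
midpoint = point_on_line(weight, weight1, 0.5)  # (weight + weight1) / 2
```

Every point on the segment is a full set of weights for the same architecture, so any alpha yields a usable network.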

Algorithm 1 (see paper) works as follows:

  1. weight and weight1 are initialized independently.
  2. For each batch (data, targets), alpha is chosen uniformly from [0, 1], and the weights w = (1 - alpha) * weight + alpha * weight1 are used in the forward pass.
  3. The regularization term is computed (see Eq. 3).
  4. With loss.backward() and optimizer.step(), the endpoints weight and weight1 are updated.
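Taken together, each iteration of the steps above samples alpha and takes a gradient step on the objective below. This restates the regularization code from the training loop later in this README rather than quoting the paper, and Eq. 3 there remains the authoritative form. Here omega and omega_1 stand for weight and weight1 of all convolutional layers flattened into single vectors, ell is the training criterion, and beta corresponds to args.beta.

```latex
\alpha \sim \mathcal{U}[0, 1], \qquad
\mathcal{L}(\omega, \omega_1)
  = \ell\big(f(x;\, (1-\alpha)\,\omega + \alpha\,\omega_1),\, y\big)
  + \beta\, \frac{\langle \omega, \omega_1 \rangle^2}{\|\omega\|_2^2\, \|\omega_1\|_2^2}
```

The second term is the squared cosine similarity between the two endpoints. It is scale-invariant and is minimized when the endpoints are orthogonal, which pushes the line to span diverse solutions.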

Instead of a regular nn.Conv2d, we use a SubspaceConv (found in modes/modules.py).

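import torch.nn as nn
import torch.nn.functional as F

class SubspaceConv(nn.Conv2d):
    # Subclasses implement get_weight(), which returns the weight for the
    # current point in the subspace; the convolution itself is unchanged.
    def forward(self, x):
        w = self.get_weight()
        x = F.conv2d(
            x,
            w,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
        return x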

For each subspace type (lines, curves, and simplexes), the function get_weight must be implemented. For lines we use:

import torch

class TwoParamConv(SubspaceConv):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Second endpoint; the first endpoint is the weight created by nn.Conv2d.
        self.weight1 = nn.Parameter(torch.zeros_like(self.weight))

    def initialize(self, initialize_fn):
        initialize_fn(self.weight1)

class LinesConv(TwoParamConv):
    def get_weight(self):
        # self.alpha is set on each module by the training loop before the forward pass.
        w = (1 - self.alpha) * self.weight + self.alpha * self.weight1
        return w
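Curves and simplexes follow the same pattern, with more parameters per layer. The sketch below is not the repository's code. The class names ThreeParamConv, BezierConv, and SimplexConv and the helper sample_simplex are hypothetical. The curve is assumed to be a quadratic Bézier curve with weight1 as its control point. The simplex point is assumed to be drawn uniformly, using normalized exponentials, which is equivalent to a Dirichlet(1, ..., 1) sample. SubspaceConv is repeated so that the block runs on its own.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SubspaceConv(nn.Conv2d):
    # Same pattern as SubspaceConv above: forward uses whatever get_weight() returns.
    def forward(self, x):
        return F.conv2d(x, self.get_weight(), self.bias, self.stride,
                        self.padding, self.dilation, self.groups)


class ThreeParamConv(SubspaceConv):
    # Hypothetical helper: two extra weight tensors besides nn.Conv2d's self.weight.
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.weight1 = nn.Parameter(torch.zeros_like(self.weight))
        self.weight2 = nn.Parameter(torch.zeros_like(self.weight))

    def initialize(self, initialize_fn):
        initialize_fn(self.weight1)
        initialize_fn(self.weight2)


class BezierConv(ThreeParamConv):
    # Assumed quadratic Bezier curve: endpoints weight (alpha = 0) and weight2
    # (alpha = 1), with weight1 as the control point. self.alpha is a float.
    def get_weight(self):
        a = self.alpha
        return (1 - a) ** 2 * self.weight + 2 * a * (1 - a) * self.weight1 + a ** 2 * self.weight2


class SimplexConv(ThreeParamConv):
    # Assumed 2-simplex (triangle) with vertices weight, weight1, weight2.
    # self.alpha holds three nonnegative coefficients that sum to 1.
    def get_weight(self):
        a = self.alpha
        return a[0] * self.weight + a[1] * self.weight1 + a[2] * self.weight2


def sample_simplex(n):
    # Uniform sample from the (n-1)-simplex: normalized i.i.d. Exponential(1)
    # draws, equivalent to a Dirichlet(1, ..., 1) sample.
    e = torch.empty(n).exponential_()
    return e / e.sum()


# Usage: a point a quarter of the way along the curve.
conv = BezierConv(3, 8, kernel_size=3, padding=1)
conv.initialize(nn.init.kaiming_normal_)
conv.alpha = 0.25
out = conv(torch.randn(2, 3, 16, 16))  # shape (2, 8, 16, 16)
```

In this sketch, a draw such as sample_simplex(3) plays the role for simplexes that np.random.uniform(0, 1) plays for lines.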

Note that in LinesConv the other endpoint, weight, is instantiated and initialized by nn.Conv2d. There is an equivalent implementation for batch norm layers, also in modes/modules.py.
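Below is a minimal sketch of a lines variant of batch norm. It assumes that only the affine scale and shift (weight, bias) lie on the line, while the running statistics are shared. SubspaceBN and LinesBN are hypothetical names here; the actual implementation is in modes/modules.py.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SubspaceBN(nn.BatchNorm2d):
    # Normalize as usual, but take the affine scale and shift from get_weight().
    # Assumes affine=True, track_running_stats=True, and a float momentum
    # (the nn.BatchNorm2d defaults).
    def forward(self, x):
        w, b = self.get_weight()
        return F.batch_norm(x, self.running_mean, self.running_var, w, b,
                            self.training, self.momentum, self.eps)


class LinesBN(SubspaceBN):
    # Hypothetical lines variant: (weight, bias) is one endpoint and
    # (weight1, bias1) the other, initialized like nn.BatchNorm2d's own parameters.
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.weight1 = nn.Parameter(torch.ones_like(self.weight))
        self.bias1 = nn.Parameter(torch.zeros_like(self.bias))

    def get_weight(self):
        # self.alpha is set externally, exactly as for LinesConv.
        w = (1 - self.alpha) * self.weight + self.alpha * self.weight1
        b = (1 - self.alpha) * self.bias + self.alpha * self.bias1
        return w, b
```

Keeping one set of running statistics for the whole line is a simplification in this sketch. It also skips nn.BatchNorm2d's num_batches_tracked bookkeeping, which only matters when momentum=None.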

Now we turn to the training logic, which appears in trainers/train_one_dim_subspaces.py. In the snippet below, we assume that we are not training with the layerwise variant (args.layerwise = False) and that we draw only one sample from the subspace (args.num_samples = 1).

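import numpy as np
import torch.nn as nn

# model, train_loader, optimizer, criterion and args are set up elsewhere in the trainer.
for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(args.device), target.to(args.device)

    # Sample one point on the line and tell every subspace layer to use it.
    alpha = np.random.uniform(0, 1)
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.BatchNorm2d)):
            setattr(m, "alpha", alpha)

    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)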

All that's left is to compute the regularization term and call backward. For lines, this is given by the snippet below.

    # Regularizer: squared cosine similarity between the two endpoints,
    # accumulated over all conv layers.
    num = 0.0
    norm = 0.0
    norm1 = 0.0
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            num += (m.weight * m.weight1).sum()
            norm += m.weight.pow(2).sum()
            norm1 += m.weight1.pow(2).sum()
    loss += args.beta * (num.pow(2) / (norm * norm1))

    loss.backward()

    optimizer.step()
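To try the loop without the repository's args, models, and data loaders, here is a self-contained toy version on random data. LinesLinear (a fully connected analogue of LinesConv), train_line, and the toy task are hypothetical. The layerwise flag assumes that the layerwise variant draws an independent alpha for every layer. That is an assumption about args.layerwise, not a copy of the trainer.

```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class LinesLinear(nn.Linear):
    # Hypothetical fully connected analogue of LinesConv.
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Second endpoint, initialized independently with nn.Linear's default scheme.
        self.weight1 = nn.Parameter(torch.empty_like(self.weight))
        nn.init.kaiming_uniform_(self.weight1, a=5 ** 0.5)
        self.alpha = 0.5

    def forward(self, x):
        w = (1 - self.alpha) * self.weight + self.alpha * self.weight1
        return F.linear(x, w, self.bias)


def train_line(steps=200, beta=1.0, layerwise=False, seed=0):
    torch.manual_seed(seed)
    rng = np.random.default_rng(seed)
    model = nn.Sequential(LinesLinear(10, 32, bias=False), nn.ReLU(),
                          LinesLinear(32, 2, bias=False))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    criterion = nn.CrossEntropyLoss()
    x = torch.randn(256, 10)
    y = (x[:, 0] > 0).long()  # toy labels: sign of the first feature
    layers = [m for m in model.modules() if isinstance(m, LinesLinear)]

    for _ in range(steps):
        # Step 2: alpha ~ U[0, 1], shared by all layers unless layerwise (assumed semantics).
        alpha = float(rng.uniform(0, 1))
        for m in layers:
            m.alpha = float(rng.uniform(0, 1)) if layerwise else alpha

        optimizer.zero_grad()
        loss = criterion(model(x), y)

        # Step 3: squared cosine similarity between the endpoints, summed over layers.
        num = sum((m.weight * m.weight1).sum() for m in layers)
        norm = sum(m.weight.pow(2).sum() for m in layers)
        norm1 = sum(m.weight1.pow(2).sum() for m in layers)
        loss = loss + beta * num.pow(2) / (norm * norm1)

        # Step 4: update both endpoints.
        loss.backward()
        optimizer.step()
    return model, x, y


model, x, y = train_line()
for m in model.modules():
    if isinstance(m, LinesLinear):
        m.alpha = 0.5  # evaluate the midpoint of the learned line
with torch.no_grad():
    midpoint_acc = (model(x).argmax(dim=1) == y).float().mean().item()
```

Because alpha changes on every step, one optimizer updates both endpoints at once. This is why training a line costs about the same as training a single model.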

Training Lines, Curves, and Simplexes

We now walk through generating the plots in Figures 4 and 5 of the paper. Before running the code, please install PyTorch and TensorBoard (to make the plots you will also need a TeX installation). Note that this repository differs from the one used to generate the figures in the paper, as the latter relied on Apple's internal tools. Accordingly, there may be some bugs; we encourage you to submit an issue or send an email if you run into any problems.

In this example walkthrough we consider TinyImageNet, which we download to ~/data using a script such as this. To run standard training and ensemble the trained models, use the following commands:

python experiment_configs/tinyimagenet/ensembles/train_ensemble_members.py
python experiment_configs/tinyimagenet/ensembles/eval_ensembles.py

If your data is not in ~/data, please change the paths in these experiment configs. Logs and checkpoints will be saved in learning-subspaces-results, although this path can also be changed.

For one-dimensional subspaces, use the following commands to train:

python experiment_configs/tinyimagenet/one_dimensional_subspaces/train_lines.py
python experiment_configs/tinyimagenet/one_dimensional_subspaces/train_lines_layerwise.py
python experiment_configs/tinyimagenet/one_dimensional_subspaces/train_curves.py

To evaluate (i.e., to generate the data for Figure 4), use:

python experiment_configs/tinyimagenet/one_dimensional_subspaces/eval_lines.py
python experiment_configs/tinyimagenet/one_dimensional_subspaces/eval_lines_layerwise.py
python experiment_configs/tinyimagenet/one_dimensional_subspaces/eval_curves.py
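The eval scripts are not reproduced here. As a hypothetical sketch of the quantities the abstract refers to, a trained line can be scored at a single alpha (for example the midpoint, alpha = 0.5). It can also be scored as an ensemble of several points whose softmax outputs are averaged. set_alpha, accuracy_at, and ensemble_accuracy are illustrative names, not functions from the repository.

```python
import torch
import torch.nn as nn


def is_subspace_layer(m):
    # Same rule as the training snippet: every Conv2d/BatchNorm2d receives alpha.
    return isinstance(m, (nn.Conv2d, nn.BatchNorm2d))


def set_alpha(model, alpha):
    for m in model.modules():
        if is_subspace_layer(m):
            setattr(m, "alpha", alpha)


@torch.no_grad()
def accuracy_at(model, x, y, alpha):
    # Accuracy of the single network at position alpha on the line.
    model.eval()
    set_alpha(model, alpha)
    return (model(x).argmax(dim=1) == y).float().mean().item()


@torch.no_grad()
def ensemble_accuracy(model, x, y, alphas):
    # Average the softmax outputs of several points on the line, then take the argmax.
    model.eval()
    probs = 0
    for a in alphas:
        set_alpha(model, a)
        probs = probs + model(x).softmax(dim=1)
    probs = probs / len(alphas)
    return (probs.argmax(dim=1) == y).float().mean().item()
```

For example, ensemble_accuracy(model, x, y, [0.0, 1.0]) ensembles the two endpoints, and accuracy_at(model, x, y, 0.5) scores the midpoint.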

We recommend looking at the experiment config files before running them; they can be modified to change the type of model, the number of random seeds, and so on. By default, these configs use 2 random seeds.

Analogously, to train simplexes use:

python experiment_configs/tinyimagenet/simplexes/train_simplexes.py
python experiment_configs/tinyimagenet/simplexes/train_simplexes_layerwise.py

To generate plots like those in Figures 4 and 5, use:

python analyze_results/tinyimagenet/one_dimensional_subspaces.py
python analyze_results/tinyimagenet/simplexes.py

Equivalent configs exist for other datasets, and the configs can be modified to add label noise, experiment with other models, and more. If there is functionality you would like that is missing from this repository, please submit an issue.

BibTeX

@article{wortsman2021learning,
  title={Learning Neural Network Subspaces},
  author={Wortsman, Mitchell and Horton, Maxwell and Guestrin, Carlos and Farhadi, Ali and Rastegari, Mohammad},
  journal={arXiv preprint arXiv:2102.10472},
  year={2021}
}
Owner: Apple