Learning Neural Network Subspaces

Overview

Learning Neural Network Subspaces

Welcome to the codebase for Learning Neural Network Subspaces by Mitchell Wortsman, Maxwell Horton, Carlos Guestrin, Ali Farhadi, Mohammad Rastegari.

Figure1

Abstract

Recent observations have advanced our understanding of the neural network optimization landscape, revealing the existence of (1) paths of high accuracy containing diverse solutions and (2) wider minima offering improved performance. Previous methods observing diverse paths require multiple training runs. In contrast we aim to leverage both property (1) and (2) with a single method and in a single training run. With a similar computational cost as training one model, we learn lines, curves, and simplexes of high-accuracy neural networks. These neural network subspaces contain diverse solutions that can be ensembled, approaching the ensemble performance of independently trained networks without the training cost. Moreover, using the subspace midpoint boosts accuracy, calibration, and robustness to label noise, outperforming Stochastic Weight Averaging.

Code Overview

In this repository we walk through learning neural network subspaces with PyTorch. We will ground the discussion with learning a line of neural networks. In our code, a line is defined by endpoints weight and weight1 and a point on the line is given by w = (1 - alpha) * weight + alpha * weight1 for some alpha in [0,1].

Algorithm 1 (see paper) works as follows:

  1. weight and weight1 are initialized independently.
  2. For each batch data, targets, alpha is chosen uniformly from [0,1] and the weights w = (1 - alpha) * weight + alpha * weight1 are used in the forward pass.
  3. The regularization term is computed (see Eq. 3).
  4. With loss.backward() and optimizer.step() the endpoints weight and weight1 are updated.

Instead of using a regular nn.Conv2d we instead use a SubspaceConv (found in modes/modules.py).

class SubspaceConv(nn.Conv2d):
    def forward(self, x):
        w = self.get_weight()
        x = F.conv2d(
            x,
            w,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
        return x

For each subspace type (lines, curves, and simplexes) the function get_weight must be implemented. For lines we use:

class TwoParamConv(SubspaceConv):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.weight1 = nn.Parameter(torch.zeros_like(self.weight))

    def initialize(self, initialize_fn):
        initialize_fn(self.weight1)

class LinesConv(TwoParamConv):
    def get_weight(self):
        w = (1 - self.alpha) * self.weight + self.alpha * self.weight1
        return w

Note that the other endpoint weight is instantiated and initialized by nn.Conv2d. Also note that there is an equivalent implementation for batch norm layers also found in modes/modules.py.

Now we turn to the training logic which appears in trainers/train_one_dim_subspaces.py. In the snippet below we assume we are not training with the layerwise variant (args.layerwise = False) and we are drawing only one sample from the subspace (args.num_samples = 1).

for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(args.device), target.to(args.device)

    alpha = np.random.uniform(0, 1)
    for m in model.modules():
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.BatchNorm2d):
            setattr(m, f"alpha", alpha)

    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)

All that's left is to compute the regularization term and call backward. For lines, this is given by the snippet below.

    num = 0.0
    norm = 0.0
    norm1 = 0.0
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            num += (self.weight * self.weight1).sum()
            norm += self.weight.pow(2).sum()
            norm1 += self.weight1.pow(2).sum()
    loss += args.beta * (num.pow(2) / (norm * norm1))

    loss.backward()

    optimizer.step()

Training Lines, Curves, and Simplexes

We now walkthrough generating the plots in Figures 4 and 5 of the paper. Before running code please install PyTorch and Tensorboard (for making plots you will also need tex on your computer). Note that this repository differs from that used to generate the figures in the paper, as the latter leveraged Apple's internal tools. Accordingly there may be some bugs and we encourage you to submit an issue or send an email if you run into any problems.

In this example walkthrough we consider TinyImageNet, which we download to ~/data using a script such as this. To run standard training and ensemble the trained models, use the following command:

python experiment_configs/tinyimagenet/ensembles/train_ensemble_members.py
python experiment_configs/tinyimagenet/ensembles/eval_ensembles.py

Note that if your data is not in ~/data please change the paths in these experiment configs. Logs and checkpoints be saved in learning-subspaces-results, although this path can also be changed.

For one dimensional subspaces, use the following command to train:

python experiment_configs/tinyimagenet/one_dimensional_subspaces/train_lines.py
python experiment_configs/tinyimagenet/one_dimensional_subspaces/train_lines_layerwise.py
python experiment_configs/tinyimagenet/one_dimensional_subspaces/train_curves.py

To evaluate (i.e. generate the data for Figure 4) use:

python experiment_configs/tinyimagenet/one_dimensional_subspaces/eval_lines.py
python experiment_configs/tinyimagenet/one_dimensional_subspaces/eval_lines_layerwise.py
python experiment_configs/tinyimagenet/one_dimensional_subspaces/eval_curves.py

We recommend looking at the experiment config files before running, which can be modified to change the type of model, number of random seeds. The default in these configs is 2 random seeds.

Analogously, to train simplexes use:

python experiment_configs/tinyimagenet/simplexes/train_simplexes.py
python experiment_configs/tinyimagenet/simplexes/train_simplexes_layerwise.py

For generating plots like those in Figure 4 and 5 use:

python analyze_results/tinyimagenet/one_dimensional_subspaces.py
python analyze_results/tinyimagenet/simplexes.py

Equivalent configs exist for other datasets, and the configs can be modified to add label noise, experiment with other models, and more. Also, if there is any functionality missing from this repository that you would like please also submit an issue.

Bibtex

@article{wortsman2021learning,
  title={Learning Neural Network Subspaces},
  author={Wortsman, Mitchell and Horton, Maxwell and Guestrin, Carlos and Farhadi, Ali and Rastegari, Mohammad},
  journal={arXiv preprint arXiv:2102.10472},
  year={2021}
}
Owner
Apple
Apple
A very short and easy implementation of Quantile Regression DQN

Quantile Regression DQN Quantile Regression DQN a Minimal Working Example, Distributional Reinforcement Learning with Quantile Regression (https://arx

Arsenii Senya Ashukha 80 Sep 17, 2022
Pytorch implementation of PTNet for high-resolution and longitudinal infant MRI synthesis

Pyramid Transformer Net (PTNet) Project | Paper Pytorch implementation of PTNet for high-resolution and longitudinal infant MRI synthesis. PTNet: A Hi

Xuzhe Johnny Zhang 6 Jun 08, 2022
A PyTorch Lightning solution to training OpenAI's CLIP from scratch.

train-CLIP 📎 A PyTorch Lightning solution to training CLIP from scratch. Goal ⚽ Our aim is to create an easy to use Lightning implementation of OpenA

Cade Gordon 396 Dec 30, 2022
Replication attempt for the Protein Folding Model

RGN2-Replica (WIP) To eventually become an unofficial working Pytorch implementation of RGN2, an state of the art model for MSA-less Protein Folding f

Eric Alcaide 36 Nov 29, 2022
Repo for my Tensorflow/Keras CV experiments. Mostly revolving around the Danbooru20xx dataset

SW-CV-ModelZoo Repo for my Tensorflow/Keras CV experiments. Mostly revolving around the Danbooru20xx dataset Framework: TF/Keras 2.7 Training SQLite D

20 Dec 27, 2022
Object tracking using YOLO and a tracker(KCF, MOSSE, CSRT) in openCV

Object tracking using YOLO and a tracker(KCF, MOSSE, CSRT) in openCV File YOLOv3 weight can be downloaded

Ngoc Quyen Ngo 2 Mar 27, 2022
PyTorch implementation of ARM-Net: Adaptive Relation Modeling Network for Structured Data.

A ready-to-use framework of latest models for structured (tabular) data learning with PyTorch. Applications include recommendation, CRT prediction, healthcare analytics, and etc.

48 Nov 30, 2022
ZSL-KG is a general-purpose zero-shot learning framework with a novel transformer graph convolutional network (TrGCN) to learn class representation from common sense knowledge graphs.

ZSL-KG is a general-purpose zero-shot learning framework with a novel transformer graph convolutional network (TrGCN) to learn class representa

Bats Research 94 Nov 21, 2022
Synthesizing and manipulating 2048x1024 images with conditional GANs

pix2pixHD Project | Youtube | Paper Pytorch implementation of our method for high-resolution (e.g. 2048x1024) photorealistic image-to-image translatio

NVIDIA Corporation 6k Dec 27, 2022
Retina blood vessel segmentation with a convolutional neural network

Retina blood vessel segmentation with a convolution neural network (U-net) This repository contains the implementation of a convolutional neural netwo

Orobix 1.2k Jan 06, 2023
A simple root calculater for python

Root A simple root calculater Usage/Examples python3 root.py 9 3 4 # Order: number - grid - number of decimals # Output: 2.08

Reza Hosseinzadeh 5 Feb 10, 2022
Human head pose estimation using Keras over TensorFlow.

RealHePoNet: a robust single-stage ConvNet for head pose estimation in the wild.

Rafael Berral Soler 71 Jan 05, 2023
Reproducing code of hair style replacement method from Barbershorp.

Barbershorp Reproducing code of hair style replacement method from Barbershorp. Also reproduces II2S, an improved version of Image2StyleGAN. Requireme

1 Dec 24, 2021
Train an RL agent to execute natural language instructions in a 3D Environment (PyTorch)

Gated-Attention Architectures for Task-Oriented Language Grounding This is a PyTorch implementation of the AAAI-18 paper: Gated-Attention Architecture

Devendra Chaplot 234 Nov 05, 2022
This is my codes that can visualize the psnr image in testing videos.

CVPR2018-Baseline-PSNRplot This is my codes that can visualize the psnr image in testing videos. Future Frame Prediction for Anomaly Detection – A New

Wenhao Yang 12 May 29, 2021
A booklet on machine learning systems design with exercises

Machine Learning Systems Design Read this booklet here. This booklet covers four main steps of designing a machine learning system: Project setup Data

Chip Huyen 7.6k Jan 08, 2023
Optimize Trading Strategies Using Freqtrade

Optimize trading strategy using Freqtrade Short demo on building, testing and optimizing a trading strategy using Freqtrade. The DevBootstrap YouTube

DevBootstrap 139 Jan 01, 2023
Identifying a Training-Set Attack’s Target Using Renormalized Influence Estimation

Identifying a Training-Set Attack’s Target Using Renormalized Influence Estimation By: Zayd Hammoudeh and Daniel Lowd Paper: Arxiv Preprint Coming soo

Zayd Hammoudeh 2 Oct 08, 2022
Official code repository for ICCV 2021 paper: Gravity-Aware Monocular 3D Human Object Reconstruction

GraviCap Official code repository for ICCV 2021 paper: Gravity-Aware Monocular 3D Human Object Reconstruction. Gravity-Aware Monocular 3D Human-Object

Rishabh Dabral 15 Dec 09, 2022
This repo. is an implementation of ACFFNet, which is accepted for in Image and Vision Computing.

Attention-Guided-Contextual-Feature-Fusion-Network-for-Salient-Object-Detection This repo. is an implementation of ACFFNet, which is accepted for in I

5 Nov 21, 2022