What can linearized neural networks actually say about generalization?

Overview

What can linearized neural networks actually say about generalization?

This is the source code to reproduce the experiments of the NeurIPS 2021 paper "What can linearized neural networks actually say about generalization?" by Guillermo Ortiz-Jimenez, Seyed-Mohsen Moosavi-Dezfooli and Pascal Frossard.

Dependencies

To run the code, please install all its dependencies by running:

$ pip install -r requirements.txt

This assumes that you have access to a Linux machine with an NVIDIA GPU with CUDA>=11.1. Otherwise, please check the instructions to install JAX with your setup in the corresponding repository.

In general, all scripts are parameterized using hydra and their configuration files can be found in the config/ folder.

Experiments

The repository contains code to reproduce the following experiments:

Spectral decomposition of NTK

To generate our new benchmark, consisting on the eigenfunctions of the NTK at initialization, please run the python script compute_ntk.py selecting a desired model (e.g., mlp, lenet or resnet18) and supporting dataset (e.g., cifar10 or mnist). This can be done by running

$ python compute_ntk.py model=lenet data.dataset=cifar10

This script will save the eigenvalues, eigenfunctions and weights of the model under artifacts/eigenfunctions/{data.dataset}/{model}/.

For other configuration options, please consult the configuration file config/compute-ntk/config.yaml.

Warning

Take into account that, for large models, this computation can take very long. For example, it took us two days to compute the full eigenvalue decomposition of the NTK of one randomly initialized ResNet18 using 4 NVIDIA V100 GPUs. The estimation of eigenvectors for the MLP or the LeNet, on the other hand, can be done in a matter of minutes, depending on the number of GPUs available and the selected batch_size

Training on binary eigenfunctions

Once you have estimated the eigenfunctions of the NTK, you should be able to train on any of them. To that end, select the desired label_idx (i.e. eigenfunction index), model and dataset, and run

$ python train_ntk.py label_idx=100 model=lenet data.dataset=cifar10 linearize=False

You can choose to train with the original non-linear network, or its linear approximation by specifying your choice with the flag linearize. For the non-linear models, this script also computes the final alignment of the end NTK with the target function, which it stores under artifacts/eigenfunctions/{data.dataset}/{model}/alignment_plots/

To see the different supported training options, please consult the configuration file config/train-ntk/config.yaml.

Estimation of NADs

We also provide code to compute the NADs of a CNN architecture (e.g., lenet or resnet18) using the alignment with the NTK at initialization. To do so, please run

$ python compute_nads.py model=lenet

This script will save the eigenvalues, NADs and weights of the model under artifacts/nads/{model}/.

For other configuration options, please consult the configuration file config/compute-nads/config.yaml.

Training on linearly separable datasets

Once you have estimated the NADs of a network, you should be able to train on linearly separable datasets with a single NAD as discriminative feature. To that end, select the desired label_idx (i.e. NAD index) and model, and run

$ python train_nads.py label_idx=100 model=lenet linearize=False

You can choose to train with the original non-linear network, or its linear approximation by specifying your choice with the flag linearize.

To see the different supported training options, please consult the configuration file config/train-nads/config.yaml.

Comparison of training dynamics with pretrained NTK

We also provide code to compare the training dynamics of the linearize network at initialization, and after non-linear pretraining, to estimate a particular eigenfunction of the NTK at initialization. To do this, please run

$ python pretrained_ntk_comparison.py label_idx=100 model=lenet data.dataset=cifar10

To see the different supported training options, please consult the configuration file config/pretrained_ntk_comparison/config.yaml.

Training on CIFAR2

Finally, you can train a neural network and its linearize approximation on the binary version of CIFAR10, i.e., CIFAR2. To do this, please run

$ python train_cifar.py model=lenet linearize=False

To see the different supported training options, please consult the configuration file config/binary-cifar/config.yaml.

Reference

If you use this code, please cite the following paper:

@InCollection{Ortiz-JimenezNeurIPS2021,
  title = {What can linearized neural networks actually say about generalization?},
  author = {{Ortiz-Jimenez}, Guillermo and {Moosavi-Dezfooli}, Seyed-Mohsen and Frossard, Pascal},
  booktitle = {Advances in Neural Information Processing Systems 35},
  month = Dec,
  year = {2021}
}
Owner
gortizji
PhD student at EPFL
gortizji
A pytorch implementation of Paper "Improved Training of Wasserstein GANs"

WGAN-GP An pytorch implementation of Paper "Improved Training of Wasserstein GANs". Prerequisites Python, NumPy, SciPy, Matplotlib A recent NVIDIA GPU

Marvin Cao 1.4k Dec 14, 2022
Code for paper: Group-CAM: Group Score-Weighted Visual Explanations for Deep Convolutional Networks

Group-CAM By Zhang, Qinglong and Rao, Lu and Yang, Yubin [State Key Laboratory for Novel Software Technology at Nanjing University] This repo is the o

zhql 98 Nov 16, 2022
Turn based roguelike in python

pyTB Turn based roguelike in python Documentation can be found here: http://mcgillij.github.io/pyTB/index.html Screenshot Dependencies Written in Pyth

Jason McGillivray 4 Sep 29, 2022
A new play-and-plug method of controlling an existing generative model with conditioning attributes and their compositions.

Viz-It Data Visualizer Web-Application If I ask you where most of the data wrangler looses their time ? It is Data Overview and EDA. Presenting "Viz-I

NVIDIA Research Projects 66 Jan 01, 2023
Train Yolov4 using NBX-Jobs

yolov4-trainer-nbox Train Yolov4 using NBX-Jobs. Use the powerfull functionality available in nbox-SDK repo to train a tiny-Yolo v4 model on Pascal VO

Yash Bonde 1 Jan 12, 2022
Implementation of Pix2Seq in PyTorch

pix2seq-pytorch Implementation of Pix2Seq paper Different from the paper image input size 1280 bin size 1280 LambdaLR scheduler used instead of Linear

Tony Shin 9 Dec 15, 2022
ComputerVision - This repository aims at realized easy network architecture

ComputerVision This repository aims at realized easy network architecture Colori

DongDong 4 Dec 14, 2022
Lightwood is Legos for Machine Learning.

Lightwood is like Legos for Machine Learning. A Pytorch based framework that breaks down machine learning problems into smaller blocks that can be glu

MindsDB Inc 312 Jan 08, 2023
Code for our ALiBi method for transformer language models.

Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation This repository contains the code and models for our paper Tra

Ofir Press 211 Dec 31, 2022
A quantum game modeling of pandemic (QHack 2022)

Contributors: @JongheumJung, @YoonjaeChung, @GyunghunKim Abstract In the regime of a global pandemic, leaders around the world need to consider variou

Yoonjae Chung 8 Apr 03, 2022
Official implementation of "Not only Look, but also Listen: Learning Multimodal Violence Detection under Weak Supervision" ECCV2020

XDVioDet Official implementation of "Not only Look, but also Listen: Learning Multimodal Violence Detection under Weak Supervision" ECCV2020. The proj

peng 64 Dec 12, 2022
Full body anonymization - Realistic Full-Body Anonymization with Surface-Guided GANs

Realistic Full-Body Anonymization with Surface-Guided GANs This is the official

Håkon Hukkelås 30 Nov 18, 2022
GMFlow: Learning Optical Flow via Global Matching

GMFlow GMFlow: Learning Optical Flow via Global Matching Authors: Haofei Xu, Jing Zhang, Jianfei Cai, Hamid Rezatofighi, Dacheng Tao We streamline the

Haofei Xu 298 Jan 04, 2023
A web application that provides real time temperature and humidity readings of a house.

About A web application which provides real time temperature and humidity readings of a house. If you're interested in the data collected so far click

Ben Thompson 3 Jan 28, 2022
Code for: Gradient-based Hierarchical Clustering using Continuous Representations of Trees in Hyperbolic Space. Nicholas Monath, Manzil Zaheer, Daniel Silva, Andrew McCallum, Amr Ahmed. KDD 2019.

gHHC Code for: Gradient-based Hierarchical Clustering using Continuous Representations of Trees in Hyperbolic Space. Nicholas Monath, Manzil Zaheer, D

Nicholas Monath 35 Nov 16, 2022
Time Series Forecasting with Temporal Fusion Transformer in Pytorch

Forecasting with the Temporal Fusion Transformer Multi-horizon forecasting often contains a complex mix of inputs – including static (i.e. time-invari

Nicolás Fornasari 6 Jan 24, 2022
Example scripts for the detection of lanes using the ultra fast lane detection model in Tensorflow Lite.

TFlite Ultra Fast Lane Detection Inference Example scripts for the detection of lanes using the ultra fast lane detection model in Tensorflow Lite. So

Ibai Gorordo 12 Aug 27, 2022
Code and experiments for "Deep Neural Networks for Rank Consistent Ordinal Regression based on Conditional Probabilities"

corn-ordinal-neuralnet This repository contains the orginal model code and experiment logs for the paper "Deep Neural Networks for Rank Consistent Ord

Raschka Research Group 14 Dec 27, 2022
A Benchmark For Measuring Systematic Generalization of Multi-Hierarchical Reasoning

Orchard Dataset This repository contains the code used for generating the Orchard Dataset, as seen in the Multi-Hierarchical Reasoning in Sequences: S

Bill Pung 1 Jun 05, 2022
The final project of "Applying AI to 2D Medical Imaging Data" of "AI for Healthcare" nanodegree - Udacity.

Pneumonia Detection from X-Rays Project Overview In this project, you will apply the skills that you have acquired in this 2D medical imaging course t

Omar Laham 1 Jan 14, 2022