What can linearized neural networks actually say about generalization?

Overview

What can linearized neural networks actually say about generalization?

This is the source code to reproduce the experiments of the NeurIPS 2021 paper "What can linearized neural networks actually say about generalization?" by Guillermo Ortiz-Jimenez, Seyed-Mohsen Moosavi-Dezfooli and Pascal Frossard.

Dependencies

To run the code, please install all its dependencies by running:

$ pip install -r requirements.txt

This assumes that you have access to a Linux machine with an NVIDIA GPU with CUDA>=11.1. Otherwise, please check the instructions to install JAX with your setup in the corresponding repository.

In general, all scripts are parameterized using hydra and their configuration files can be found in the config/ folder.

Experiments

The repository contains code to reproduce the following experiments:

Spectral decomposition of NTK

To generate our new benchmark, consisting on the eigenfunctions of the NTK at initialization, please run the python script compute_ntk.py selecting a desired model (e.g., mlp, lenet or resnet18) and supporting dataset (e.g., cifar10 or mnist). This can be done by running

$ python compute_ntk.py model=lenet data.dataset=cifar10

This script will save the eigenvalues, eigenfunctions and weights of the model under artifacts/eigenfunctions/{data.dataset}/{model}/.

For other configuration options, please consult the configuration file config/compute-ntk/config.yaml.

Warning

Take into account that, for large models, this computation can take very long. For example, it took us two days to compute the full eigenvalue decomposition of the NTK of one randomly initialized ResNet18 using 4 NVIDIA V100 GPUs. The estimation of eigenvectors for the MLP or the LeNet, on the other hand, can be done in a matter of minutes, depending on the number of GPUs available and the selected batch_size

Training on binary eigenfunctions

Once you have estimated the eigenfunctions of the NTK, you should be able to train on any of them. To that end, select the desired label_idx (i.e. eigenfunction index), model and dataset, and run

$ python train_ntk.py label_idx=100 model=lenet data.dataset=cifar10 linearize=False

You can choose to train with the original non-linear network, or its linear approximation by specifying your choice with the flag linearize. For the non-linear models, this script also computes the final alignment of the end NTK with the target function, which it stores under artifacts/eigenfunctions/{data.dataset}/{model}/alignment_plots/

To see the different supported training options, please consult the configuration file config/train-ntk/config.yaml.

Estimation of NADs

We also provide code to compute the NADs of a CNN architecture (e.g., lenet or resnet18) using the alignment with the NTK at initialization. To do so, please run

$ python compute_nads.py model=lenet

This script will save the eigenvalues, NADs and weights of the model under artifacts/nads/{model}/.

For other configuration options, please consult the configuration file config/compute-nads/config.yaml.

Training on linearly separable datasets

Once you have estimated the NADs of a network, you should be able to train on linearly separable datasets with a single NAD as discriminative feature. To that end, select the desired label_idx (i.e. NAD index) and model, and run

$ python train_nads.py label_idx=100 model=lenet linearize=False

You can choose to train with the original non-linear network, or its linear approximation by specifying your choice with the flag linearize.

To see the different supported training options, please consult the configuration file config/train-nads/config.yaml.

Comparison of training dynamics with pretrained NTK

We also provide code to compare the training dynamics of the linearize network at initialization, and after non-linear pretraining, to estimate a particular eigenfunction of the NTK at initialization. To do this, please run

$ python pretrained_ntk_comparison.py label_idx=100 model=lenet data.dataset=cifar10

To see the different supported training options, please consult the configuration file config/pretrained_ntk_comparison/config.yaml.

Training on CIFAR2

Finally, you can train a neural network and its linearize approximation on the binary version of CIFAR10, i.e., CIFAR2. To do this, please run

$ python train_cifar.py model=lenet linearize=False

To see the different supported training options, please consult the configuration file config/binary-cifar/config.yaml.

Reference

If you use this code, please cite the following paper:

@InCollection{Ortiz-JimenezNeurIPS2021,
  title = {What can linearized neural networks actually say about generalization?},
  author = {{Ortiz-Jimenez}, Guillermo and {Moosavi-Dezfooli}, Seyed-Mohsen and Frossard, Pascal},
  booktitle = {Advances in Neural Information Processing Systems 35},
  month = Dec,
  year = {2021}
}
Owner
gortizji
PhD student at EPFL
gortizji
Code for the paper "Unsupervised Contrastive Learning of Sound Event Representations", ICASSP 2021.

Unsupervised Contrastive Learning of Sound Event Representations This repository contains the code for the following paper. If you use this code or pa

Eduardo Fonseca 81 Dec 22, 2022
なりすまし検出(anti-spoof-mn3)のWebカメラ向けデモ

FaceDetection-Anti-Spoof-Demo なりすまし検出(anti-spoof-mn3)のWebカメラ向けデモです。 モデルはPINTO_model_zoo/191_anti-spoof-mn3からONNX形式のモデルを使用しています。 Requirement mediapipe

KazuhitoTakahashi 8 Nov 18, 2022
PyTorch implementation of GLOM

GLOM PyTorch implementation of GLOM, Geoffrey Hinton's new idea that integrates concepts from neural fields, top-down-bottom-up processing, and attent

Yeonwoo Sung 20 Aug 17, 2022
A repository for interferometer controller code.

dses-interferometer-controller A repository for interferometer controller code, hardware, and simulations. See dses.science for more information on th

Eli Reed 1 Jan 17, 2022
Adjust Decision Boundary for Class Imbalanced Learning

Adjusting Decision Boundary for Class Imbalanced Learning This repository is the official PyTorch implementation of WVN-RS, introduced in Adjusting De

Peyton Byungju Kim 16 Jan 04, 2023
A Pytorch implementation of MoveNet from Google. Include training code and pre-train model.

Movenet.Pytorch Intro MoveNet is an ultra fast and accurate model that detects 17 keypoints of a body. This is A Pytorch implementation of MoveNet fro

Mr.Fire 241 Dec 26, 2022
A Joint Video and Image Encoder for End-to-End Retrieval

Frozen️ in Time ❄️ ️️️️ ⏳ A Joint Video and Image Encoder for End-to-End Retrieval project page | arXiv | webvid-data Repository containing the code,

225 Dec 25, 2022
Tensorflow Implementation of ECCV'18 paper: Multimodal Human Motion Synthesis

MT-VAE for Multimodal Human Motion Synthesis This is the code for ECCV 2018 paper MT-VAE: Learning Motion Transformations to Generate Multimodal Human

Xinchen Yan 36 Oct 02, 2022
Pyeventbus: a publish/subscribe event bus

pyeventbus pyeventbus is a publish/subscribe event bus for Python 2.7. simplifies the communication between python classes decouples event senders and

15 Apr 21, 2022
AIR^2 for Interaction Prediction

This is the repository for AIR^2 for Interaction Prediction. Explanation of the solution: Video: link License AIR is released under the Apache 2.0 lic

21 Sep 27, 2022
Deep Federated Learning for Autonomous Driving

FADNet: Deep Federated Learning for Autonomous Driving Abstract Autonomous driving is an active research topic in both academia and industry. However,

AIOZ AI 12 Dec 01, 2022
Visual odometry package based on hardware-accelerated NVIDIA Elbrus library with world class quality and performance.

Isaac ROS Visual Odometry This repository provides a ROS2 package that estimates stereo visual inertial odometry using the Isaac Elbrus GPU-accelerate

NVIDIA Isaac ROS 343 Jan 03, 2023
Zen-NAS: A Zero-Shot NAS for High-Performance Deep Image Recognition

Zen-NAS: A Zero-Shot NAS for High-Performance Deep Image Recognition How Fast Compare to Other Zero-Shot NAS Proxies on CIFAR-10/100 Pre-trained Model

190 Dec 29, 2022
[AI6101] Introduction to AI & AI Ethics is a core course of MSAI, SCSE, NTU, Singapore

[AI6101] Introduction to AI & AI Ethics is a core course of MSAI, SCSE, NTU, Singapore. The repository corresponds to the AI6101 of Semester 1, AY2021-2022, starting from 08/2021. The instructors of

AccSrd 1 Sep 22, 2022
Codes and pretrained weights for winning submission of 2021 Brain Tumor Segmentation (BraTS) Challenge

Winning submission to the 2021 Brain Tumor Segmentation Challenge This repo contains the codes and pretrained weights for the winning submission to th

94 Dec 28, 2022
This repository contains the code for using the H3DS dataset introduced in H3D-Net: Few-Shot High-Fidelity 3D Head Reconstruction

H3DS Dataset This repository contains the code for using the H3DS dataset introduced in H3D-Net: Few-Shot High-Fidelity 3D Head Reconstruction Access

Crisalix 72 Dec 10, 2022
Baseline and template code for node21 detection track

Nodule Detection Algorithm This codebase implements a baseline model, Faster R-CNN, for the nodule detection track in NODE21. It contains all necessar

node21challenge 11 Jan 15, 2022
ShapeGlot: Learning Language for Shape Differentiation

ShapeGlot: Learning Language for Shape Differentiation Created by Panos Achlioptas, Judy Fan, Robert X.D. Hawkins, Noah D. Goodman, Leonidas J. Guibas

Panos 32 Dec 23, 2022
Camera Distortion-aware 3D Human Pose Estimation in Video with Optimization-based Meta-Learning

Camera Distortion-aware 3D Human Pose Estimation in Video with Optimization-based Meta-Learning This is the official repository of "Camera Distortion-

Hanbyel Cho 12 Oct 06, 2022
A framework for joint super-resolution and image synthesis, without requiring real training data

SynthSR This repository contains code to train a Convolutional Neural Network (CNN) for Super-resolution (SR), or joint SR and data synthesis. The met

83 Jan 01, 2023