Code accompanying the paper "An Empirical Investigation of Domain Generalization with Empirical Risk Minimizers", published at NeurIPS 2021

Overview

Code for "An Empirical Investigation of Domain Generalization with Empirical Risk Minimizers" (NeurIPS 2021)

Motivation and Introduction

Domain generalization is a machine learning problem in which a model is trained on data from one or more input distributions and is expected to perform well on test data drawn from a different input distribution. For example, one might train a digit classifier on MNIST and ask it to generalize to digits rotated by, say, 30 degrees.

While many approaches have been proposed for this problem, we were intrigued by results on the DomainBed benchmark suggesting that simple empirical risk minimization (ERM), combined with a proper hyperparameter sweep, performs close to the state of the art on domain generalization problems.

What governs how well a deep learning model trained with ERM generalizes to a given data distribution? This is the question we seek to answer in our NeurIPS 2021 paper:

An Empirical Investigation of Domain Generalization with Empirical Risk Minimizers. Rama Vedantam, David Lopez-Paz*, David Schwab*.

NeurIPS 2021 (*=Equal Contribution)

This repository contains code used for producing the results in our paper.

Initial Setup

  1. Run source init.sh to install all dependencies for the project. This also initializes DomainBed as a submodule of the project.

  2. Set the required paths in setup.sh, then run source setup.sh.
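Later commands read the paths DOMAINBED_DATA, DOMAINBED_RUN_DIR and MEASURE_RUN_DIR from the environment. The following minimal sketch is a sanity check you could run after sourcing setup.sh. It assumes setup.sh exports exactly these three variables. The helper missing_env_vars is hypothetical and is not part of this repository.

```python
import os

# Paths read by the commands in this README.
# Assumption: setup.sh exports all three of them.
REQUIRED_VARS = ("DOMAINBED_DATA", "DOMAINBED_RUN_DIR", "MEASURE_RUN_DIR")


def missing_env_vars(env=None):
    """Return the required variables that are unset or empty (defaults to os.environ)."""
    env = os.environ if env is None else env
    return [name for name in REQUIRED_VARS if not env.get(name)]
```

For example, missing_env_vars() returns an empty list once all three variables are set to non-empty values.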

Computing Generalization Measures

  • Get set up with the DomainBed codebase and launch a sweep for an initial set of trained models (illustrated below for the RotatedMNIST dataset):
cd DomainBed/

python -m domainbed.scripts.sweep launch \
       --data_dir=${DOMAINBED_DATA} \
       --output_dir=${DOMAINBED_RUN_DIR}/sweep_fifty_fifty \
       --algorithms=ERM \
       --holdout_fraction=0.5 \
       --datasets=RotatedMNIST \
       --n_hparams=1 \
       --command_launcher submitit
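To get a feel for how many models such a sweep trains, here is a rough sketch of the job grid. It is loosely modelled on DomainBed's sweep script and makes two assumptions: each run holds out either one or two test environments, and runs are repeated over several trial seeds. The exact rule lives in DomainBed/domainbed/scripts/sweep.py and may differ. The helper names are hypothetical.

```python
from itertools import product


def test_env_combinations(n_envs):
    """Yield every choice of one or two held-out test environments.

    Assumption: this approximates the default DomainBed sweep; check
    domainbed/scripts/sweep.py for the exact rule.
    """
    for i in range(n_envs):
        yield [i]
        for j in range(i + 1, n_envs):
            yield [i, j]


def sweep_jobs(n_envs, n_trials, n_hparams, algorithms=("ERM",)):
    """List (algorithm, test_envs, hparams_seed, trial_seed) tuples, one per training run."""
    jobs = []
    for trial_seed, algorithm, hparams_seed in product(range(n_trials), algorithms, range(n_hparams)):
        for test_envs in test_env_combinations(n_envs):
            jobs.append((algorithm, test_envs, hparams_seed, trial_seed))
    return jobs
```

Under these assumptions, RotatedMNIST has six rotation environments and therefore 21 test-environment choices. With --n_hparams=1 and three trial seeds, the sweep would train 63 models.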

After this step, we have a set of trained models to evaluate and measure. Note that, unlike the original DomainBed paper, we hold out a larger fraction (50%) of the data for evaluating the measures.
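The held-out fraction can be pictured as a seeded split of each environment. This minimal sketch assumes that each environment is shuffled with a fixed seed and that the first int(len(env) * holdout_fraction) examples form the held-out split, roughly as DomainBed does. It uses Python's random module, so the exact permutation will not match DomainBed's. split_environment is a hypothetical helper, not this repository's code.

```python
import random


def split_environment(examples, holdout_fraction=0.5, seed=0):
    """Split one environment into (in_split, out_split) with a seeded shuffle.

    The first int(len(examples) * holdout_fraction) shuffled examples are held out;
    in this setup the held-out split is what the measures are evaluated on.
    """
    indices = list(range(len(examples)))
    random.Random(seed).shuffle(indices)
    n_out = int(len(examples) * holdout_fraction)
    out_split = [examples[i] for i in indices[:n_out]]
    in_split = [examples[i] for i in indices[n_out:]]
    return in_split, out_split
```

With holdout_fraction=0.5, a 10-example environment is split 5/5. A fraction of 0.2 would hold out only 2 examples.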

  • Once the sweep finishes, aggregate the different files for use by the domainbed_measures codebase:
python domainbed_measures/write_job_status_file.py \
                --sweep_dir=${DOMAINBED_RUN_DIR}/sweep_fifty_fifty \
                --output_txt="domainbed_measures/scratch/sweep_release.txt"
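The sketch below lists finished runs, to give a sense of what this aggregation step works with. DomainBed's training script writes a file named done into a run's output directory when training completes. The sketch assumes one subdirectory per run under the sweep directory and writes one finished run path per line. That output format is an assumption for illustration only; the actual contents of sweep_release.txt produced by write_job_status_file.py may differ. Both helpers are hypothetical.

```python
import os


def finished_runs(sweep_dir):
    """Return sorted paths of run directories under sweep_dir that contain a `done` file."""
    runs = []
    for name in sorted(os.listdir(sweep_dir)):
        run_dir = os.path.join(sweep_dir, name)
        # A run counts as finished once its training job has written the `done` marker.
        if os.path.isfile(os.path.join(run_dir, "done")):
            runs.append(run_dir)
    return runs


def write_status_file(sweep_dir, output_txt):
    """Write one finished run directory per line (hypothetical format); return the count."""
    runs = finished_runs(sweep_dir)
    with open(output_txt, "w") as f:
        f.write("\n".join(runs) + ("\n" if runs else ""))
    return len(runs)
```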
  • Once this step is complete, we can compute various generalization measures and store them to disk for future analysis using:
SLURM_PARTITION="TO_BE_SET"
python domainbed_measures/compute_gen_correlations.py \
    --algorithm=ERM \
    --job_done_file="domainbed_measures/scratch/sweep_release.txt" \
    --run_dir=${MEASURE_RUN_DIR} \
    --all_measures_one_job \
    --slurm_partition=${SLURM_PARTITION}

Here we use Slurm on a compute cluster to scale the experiments to thousands of models. If you do not have access to such a cluster with multiple GPUs to parallelize the computation, pass --slurm_partition="" in the command above and the code will run on a single GPU (although the results may take a long time to compute!).
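The choice between cluster and single-GPU execution can be sketched as follows. This is not how compute_gen_correlations.py is implemented; it is a minimal illustration. It assumes the submitit library (the same launcher passed to the DomainBed sweep above) and uses placeholder resources: one GPU and a 60-minute timeout. The function launch is hypothetical.

```python
def launch(fn, args_list, slurm_partition="", log_dir="submitit_logs"):
    """Run fn(args) for each args in args_list, locally or as Slurm jobs.

    An empty partition runs every call sequentially in this process; otherwise
    each call is submitted as a separate Slurm job through submitit.
    """
    if not slurm_partition:
        return [fn(args) for args in args_list]

    import submitit  # only needed when a Slurm partition is given

    executor = submitit.AutoExecutor(folder=log_dir)
    # Placeholder resources (assumption): one GPU and a 60-minute limit per job.
    executor.update_parameters(slurm_partition=slurm_partition, gpus_per_node=1, timeout_min=60)
    jobs = [executor.submit(fn, args) for args in args_list]
    return [job.result() for job in jobs]  # blocks until every job has finished
```

With an empty partition, every call runs in the current process, one after another, which mirrors the single-GPU fallback described above. submitit is imported only when a partition is given.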

  • Finally, once the above jobs are done, run the following command to aggregate the values of the different generalization measures:
python domainbed_measures/extract_generalization_features.py \
    --run_dir=${MEASURE_RUN_DIR} \
    --sweep_name="_out_ERM_RotatedMNIST"

This step yields .csv files in which each row corresponds to one trained model. Each row has the following format:

dataset | test_envs | measure 1 | measure 2 | measure 3 | target_err

where:

  • test_envs specifies the environments the model is tested on; equivalently, it determines the environments the model was trained on, since all remaining environments are used for training
  • target_err specifies the target error value for the regression
  • measure 1, measure 2, ... each hold the value of one generalization measure, e.g. sharpness or Fisher-eigenvalue-based measures
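Here is a minimal sketch of loading such a file with Python's csv module. It assumes the header row uses the column names dataset, test_envs and target_err exactly as shown above, and that every other column holds a numeric measure. The real files may carry extra columns, such as an index, which you would need to exclude. load_measures is a hypothetical helper.

```python
import csv
from collections import defaultdict

# Columns that are not generalization measures, per the row format above.
NON_MEASURE_COLUMNS = {"dataset", "test_envs", "target_err"}


def load_measures(csv_file):
    """Read a measures CSV into (measure_names, rows_by_test_envs).

    Each row becomes a dict of floats for the measure columns plus target_err,
    and rows are grouped by their test_envs value.
    """
    reader = csv.DictReader(csv_file)
    measure_names = [c for c in reader.fieldnames if c not in NON_MEASURE_COLUMNS]
    rows_by_test_envs = defaultdict(list)
    for row in reader:
        record = {name: float(row[name]) for name in measure_names}
        record["target_err"] = float(row["target_err"])
        rows_by_test_envs[row["test_envs"]].append(record)
    return measure_names, dict(rows_by_test_envs)
```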

For a file such as sweeps__out_ERM_RotatedMNIST_canon_False_ood.csv, the in-domain validation error wd_out_domain_err is also used as one of the measures, target_err is the out-of-domain generalization error, and all measures are computed on a held-out set of image inputs from the target domain (see the paper for details).

In contrast, for a file such as sweeps__out_ERM_RotatedMNIST_canon_False_wd.csv, target_err is the in-domain validation accuracy, and all measures are computed on the in-distribution held-out images.

  • These files can be used to run the regression analyses for measuring generalization that are reported in the paper.

For example, to generate the kind of results shown in Table 1 of the paper for the joint setting, run:

python domainbed_measures/analyze_results.py \
    --input_csv="${MEASURE_RUN_DIR}/sweeps__out_ERM_RotatedMNIST_canon_False_ood.csv" \
    --stratified_or_joint="joint" \
    --num_features=2 \
    --fix_one_feature_to_wd

Alternatively, to generate results in the stratified setting, run:

python domainbed_measures/analyze_results.py \
    --input_csv="${MEASURE_RUN_DIR}/sweeps__out_ERM_RotatedMNIST_canon_False_ood.csv" \
    --stratified_or_joint="stratified" \
    --num_features=2 \
    --fix_one_feature_to_wd
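The joint setting fits one regression of target_err over all trained models. The stratified setting fits a separate regression within each group of models sharing the same test_envs and then averages the scores. The sketch below illustrates this difference under three assumptions:

- the model is ordinary least squares, scored by R²;
- stratified scores are averaged with equal weight per group;
- --fix_one_feature_to_wd with --num_features=2 pairs each measure with wd_out_domain_err.

These are assumptions for illustration only; the paper and analyze_results.py may use a different model, score, or aggregation. The input is a dict mapping each test_envs value to a list of row dicts holding the measure columns and target_err. All function names are hypothetical.

```python
def fit_ols(X, y):
    """Least squares for y ~ b0 + b1*x1 + ... via the normal equations; returns [b0, b1, ...]."""
    rows = [[1.0] + list(x) for x in X]  # prepend an intercept column
    k = len(rows[0])
    # Augmented matrix [X^T X | X^T y].
    M = [[sum(r[i] * r[j] for r in rows) for j in range(k)]
         + [sum(r[i] * t for r, t in zip(rows, y))] for i in range(k)]
    # Gauss-Jordan elimination with partial pivoting (assumes X^T X is invertible).
    for col in range(k):
        pivot = max(range(col, k), key=lambda r: abs(M[r][col]))
        M[col], M[pivot] = M[pivot], M[col]
        for r in range(k):
            if r != col:
                factor = M[r][col] / M[col][col]
                M[r] = [a - factor * b for a, b in zip(M[r], M[col])]
    return [M[i][k] / M[i][i] for i in range(k)]


def r_squared(X, y):
    """Coefficient of determination of the OLS fit (undefined if y is constant)."""
    coef = fit_ols(X, y)
    preds = [coef[0] + sum(c * v for c, v in zip(coef[1:], x)) for x in X]
    mean_y = sum(y) / len(y)
    ss_res = sum((t - p) ** 2 for t, p in zip(y, preds))
    ss_tot = sum((t - mean_y) ** 2 for t in y)
    return 1.0 - ss_res / ss_tot


def design(rows, measure, fix_one_feature_to_wd=True):
    """Feature matrix and targets; optionally pair `measure` with the in-domain error."""
    cols = ["wd_out_domain_err", measure] if fix_one_feature_to_wd else [measure]
    return [[row[c] for c in cols] for row in rows], [row["target_err"] for row in rows]


def joint_score(rows_by_test_envs, measure, fix_one_feature_to_wd=True):
    """One regression over all models, regardless of their test environments."""
    rows = [row for group in rows_by_test_envs.values() for row in group]
    return r_squared(*design(rows, measure, fix_one_feature_to_wd))


def stratified_score(rows_by_test_envs, measure, fix_one_feature_to_wd=True):
    """One regression per test_envs group, averaged with equal weight per group."""
    scores = [r_squared(*design(group, measure, fix_one_feature_to_wd))
              for group in rows_by_test_envs.values()]
    return sum(scores) / len(scores)
```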

Finally, to generate results using a single feature (the Alone setting in Table 1), run:

python domainbed_measures/analyze_results.py \
    --input_csv="${MEASURE_RUN_DIR}/sweeps__out_ERM_RotatedMNIST_canon_False_ood.csv" \
    --num_features=1
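For a single feature, the R² of a least-squares line with an intercept equals the squared Pearson correlation between the measure and target_err. The Alone setting can therefore be pictured as ranking measures by that quantity. This minimal sketch assumes R² as the score; the script's actual metric may differ. rank_single_measures is hypothetical and takes a list of row dicts, such as rows read from the CSV.

```python
def pearson_r(x, y):
    """Pearson correlation coefficient (undefined if x or y is constant)."""
    n = len(x)
    mean_x, mean_y = sum(x) / n, sum(y) / n
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    var_x = sum((a - mean_x) ** 2 for a in x)
    var_y = sum((b - mean_y) ** 2 for b in y)
    return cov / (var_x * var_y) ** 0.5


def rank_single_measures(rows, measure_names):
    """Rank measures by the R^2 of a one-feature linear fit to target_err, best first."""
    y = [row["target_err"] for row in rows]
    # With one feature plus an intercept, the OLS R^2 equals the squared Pearson correlation.
    scores = {m: pearson_r([row[m] for row in rows], y) ** 2 for m in measure_names}
    return sorted(scores.items(), key=lambda item: item[1], reverse=True)
```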

Translation of measures from the code to the paper

The following table lists all the measures in the paper (Table 2 in the Appendix) and how they are referred to in the codebase:

Measure Name | Code Reference
H-divergence | c2st
H-divergence + Source Error | c2st_perr
H-divergence MS | c2st_per_env
H-divergence MS + Source Error | c2st_per_env_perr
H-divergence (train) | c2st_train
H-divergence (train) + Source Error | c2st_train_perr
H-divergence (train) MS | c2st_train_per_env
Entropy-Source or Entropy | entropy
Entropy-Target | entropy_held_out
Fisher-Eigval-Diff | fisher_eigval_sum_diff_ex_75
Fisher-Eigval | fisher_eigval_sum_ex_75
Fisher-Align or Fisher (main paper) | fisher_eigvec_align_ex_75
HΔH-divergence SS | hdh
HΔH-divergence SS + Source Error | hdh_perr
HΔH-divergence MS | hdh_per_env
HΔH-divergence MS + Source Error | hdh_per_env_perr
HΔH-divergence (train) SS | hdh_train
HΔH-divergence (train) SS + Source Error | hdh_train_perr
Jacobian | jacobian_norm
Jacobian Ratio | jacobian_norm_relative
Jacobian Diff | jacobian_norm_relative_diff
Jacobian Log Ratio | jacobian_norm_relative_log_diff
Mixup | mixup
Mixup Ratio | mixup_relative
Mixup Diff | mixup_relative_diff
Mixup Log Ratio | mixup_relative_log_diff
MMD-Gaussian | mmd_gaussian
MMD-Mean-Cov | mmd_mean_cov
L2-Path-Norm | path_norm
Sharpness | sharp_mag
H+-divergence SS | v_plus_c2st
H+-divergence SS + Source Error | v_plus_c2st_perr
H+-divergence MS | v_plus_c2st_per_env
H+-divergence MS + Source Error | v_plus_c2st_per_env_perr
H+ΔH+-divergence SS | v_plus_hdh
H+ΔH+-divergence SS + Source Error | v_plus_hdh_perr
H+ΔH+-divergence MS | v_plus_hdh_per_env
H+ΔH+-divergence MS + Source Error | v_plus_hdh_per_env_perr
Source Error | wd_out_domain_err
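Most divergence-based entries in the table follow a naming convention in the code:

- a _perr suffix adds "+ Source Error";
- _per_env marks the MS variant;
- _train marks the (train) variant;
- a v_plus_ prefix selects the H+ family.

The sketch below decodes such names. It is derived only from the rows in the table above, where plain H-divergence carries no SS tag, and it covers only the c2st and hdh families. paper_name is a hypothetical helper.

```python
# Base divergences and their paper names, taken from the table above.
BASE_NAMES = {
    "c2st": "H-divergence",
    "hdh": "HΔH-divergence",
    "v_plus_c2st": "H+-divergence",
    "v_plus_hdh": "H+ΔH+-divergence",
}
# In the table, every non-MS variant except the plain c2st family is tagged "SS".
SS_TAGGED = {"hdh", "v_plus_c2st", "v_plus_hdh"}


def paper_name(code_name):
    """Translate a divergence measure's code name into its paper name via the suffix convention."""
    name = code_name
    source_error = name.endswith("_perr")
    if source_error:
        name = name[: -len("_perr")]
    multi_source = name.endswith("_per_env")
    if multi_source:
        name = name[: -len("_per_env")]
    train = name.endswith("_train")
    if train:
        name = name[: -len("_train")]
    if name not in BASE_NAMES:
        raise KeyError(f"not a divergence measure: {code_name}")
    parts = [BASE_NAMES[name]]
    if train:
        parts.append("(train)")
    if multi_source:
        parts.append("MS")
    elif name in SS_TAGGED:
        parts.append("SS")
    if source_error:
        parts.append("+ Source Error")
    return " ".join(parts)
```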

Acknowledgments

We thank the developers of Decodable Information Bottleneck and DomainBed, as well as Jonathan Frankle, for code we found useful for this project.

License

This source code is released under the Creative Commons Attribution-NonCommercial 4.0 International license, a copy of which is included in this repository.

Owner
Meta Research