Official Implementation of SWAD (NeurIPS 2021)


SWAD: Domain Generalization by Seeking Flat Minima (NeurIPS'21)

Official PyTorch implementation of SWAD: Domain Generalization by Seeking Flat Minima.
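SWAD (Stochastic Weight Averaging Densely) looks for flat minima by averaging model weights collected densely along the training trajectory, at every iteration inside an averaging window rather than once per epoch as in SWA. The sketch below shows only the averaging step, as a uniform running mean. It assumes plain Python lists stand in for parameter tensors, and the names `Params` and `RunningWeightAverage` are hypothetical. It is not the repository's PyTorch implementation and omits how the averaging window is chosen.

```python
from typing import Dict, List

# Hypothetical stand-in for a model state_dict: parameter name -> flat values.
Params = Dict[str, List[float]]


class RunningWeightAverage:
    """Uniform running average of parameter snapshots (illustrative only)."""

    def __init__(self) -> None:
        self.avg: Params = {}
        self.count = 0

    def update(self, params: Params) -> None:
        # Incremental mean: avg_n = avg_{n-1} + (x_n - avg_{n-1}) / n
        self.count += 1
        if self.count == 1:
            self.avg = {name: list(values) for name, values in params.items()}
            return
        for name, values in params.items():
            current = self.avg[name]
            for i, x in enumerate(values):
                current[i] += (x - current[i]) / self.count


# Dense averaging: call update() after every optimizer step inside the window.
averager = RunningWeightAverage()
for step_weights in ({"w": [1.0, 2.0]}, {"w": [3.0, 4.0]}, {"w": [5.0, 6.0]}):
    averager.update(step_weights)
print(averager.avg)  # {'w': [3.0, 4.0]}
```

The incremental form avoids storing every snapshot, which matters when averaging at every iteration.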

Junbum Cha, Sanghyuk Chun, Kyungjae Lee, Han-Cheol Cho, Seunghyun Park, Yunsung Lee, Sungrae Park.

Note that this project is built upon DomainBed.

Preparation

Dependencies

pip install -r requirements.txt

Datasets

python -m domainbed.scripts.download --data_dir=/my/datasets/path

Environments

Environment details used for our study.

Python: 3.8.6
PyTorch: 1.7.0+cu92
Torchvision: 0.8.1+cu92
CUDA: 9.2
CUDNN: 7603
NumPy: 1.19.4
PIL: 8.0.1
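To compare a local setup against the versions listed above, a small check like the following may help. It is a sketch, not part of the repository: the helper names `base_version` and `check_environment` are hypothetical, and it ignores local build tags such as `+cu92` when comparing.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions listed above; PIL is distributed as the "Pillow" package.
EXPECTED = {
    "torch": "1.7.0+cu92",
    "torchvision": "0.8.1+cu92",
    "numpy": "1.19.4",
    "Pillow": "8.0.1",
}


def base_version(v: str) -> str:
    """Drop a PEP 440 local version label such as '+cu92'."""
    return v.split("+", 1)[0]


def check_environment(expected=EXPECTED):
    """Map package -> (expected, installed or None, base versions match)."""
    report = {}
    for package, wanted in expected.items():
        try:
            installed = version(package)
        except PackageNotFoundError:
            installed = None
        matches = installed is not None and base_version(installed) == base_version(wanted)
        report[package] = (wanted, installed, matches)
    return report


if __name__ == "__main__":
    for package, (wanted, installed, ok) in check_environment().items():
        status = "OK" if ok else "differs"
        print(f"{package:12s} expected {wanted:12s} installed {installed!s:12s} {status}")
```

CUDA and cuDNN are not covered here; in PyTorch they can be read from `torch.version.cuda` and `torch.backends.cudnn.version()`.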

How to Run

The train_all.py script runs leave-one-domain-out cross-validation: each domain of the dataset is used in turn as the target (test) domain.

python train_all.py exp_name --dataset PACS --data_dir /my/datasets/path
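Conceptually, each run trains on every domain except one and evaluates on the held-out domain. The sketch below only enumerates those (target, sources) splits for PACS, using the domain names from the results table; the function name is hypothetical and this is not code taken from train_all.py.

```python
from typing import List, Tuple

# PACS domain names as they appear in the example results table.
PACS_DOMAINS = ["art_painting", "cartoon", "photo", "sketch"]


def leave_one_domain_out(domains: List[str]) -> List[Tuple[str, List[str]]]:
    """Return (target_domain, source_domains) pairs, one per held-out domain."""
    return [(target, [d for d in domains if d != target]) for target in domains]


splits = leave_one_domain_out(PACS_DOMAINS)
for target, sources in splits:
    print(f"train on {sources} -> test on {target}")
```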

Experiment results are reported as a table. In the table, the SWAD row reports SWAD's out-of-domain accuracy, and the SWAD (inD) row reports its in-domain validation accuracy.

Example results:

+------------+--------------+---------+---------+---------+---------+
| Selection  | art_painting | cartoon |  photo  |  sketch |   Avg.  |
+------------+--------------+---------+---------+---------+---------+
|   oracle   |   82.245%    | 85.661% | 97.530% | 83.461% | 87.224% |
|    iid     |   87.919%    | 78.891% | 96.482% | 78.435% | 85.432% |
|    last    |   82.306%    | 81.823% | 95.135% | 82.061% | 85.331% |
| last (inD) |   95.807%    | 95.291% | 96.306% | 95.477% | 95.720% |
| iid (inD)  |   97.275%    | 96.619% | 96.696% | 97.253% | 96.961% |
|    SWAD    |   89.750%    | 82.942% | 97.979% | 81.870% | 88.135% |
| SWAD (inD) |   97.713%    | 97.649% | 97.316% | 98.074% | 97.688% |
+------------+--------------+---------+---------+---------+---------+

In this example, the domain generalization (DG) performance of SWAD on the PACS dataset is 88.135%.
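The Avg. column in the example table matches the unweighted mean of the per-domain accuracies. Reproducing the SWAD figure from the table values:

```python
# SWAD row of the example table (out-of-domain accuracy, %).
swad_acc = {
    "art_painting": 89.750,
    "cartoon": 82.942,
    "photo": 97.979,
    "sketch": 81.870,
}

# Unweighted mean over the four target domains.
avg = sum(swad_acc.values()) / len(swad_acc)
print(f"{avg:.3f}")  # 88.135
```

Because each domain counts equally, a small target domain influences the average as much as a large one.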

If you set the indomain_test option to True, the in-domain validation set is split into validation and test sets, and the (inD) rows then report in-domain test accuracy.
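The idea of carving a held-out test part out of the in-domain validation data can be sketched as below. This is a hypothetical illustration: the function name, the 50/50 ratio, and the fixed seed are assumptions, and the repository's actual split logic is not described in this README.

```python
import random
from typing import List, Sequence, Tuple


def split_in_domain(
    indices: Sequence[int], test_fraction: float = 0.5, seed: int = 0
) -> Tuple[List[int], List[int]]:
    """Hypothetical: split in-domain validation indices into (val, test)."""
    rng = random.Random(seed)  # fixed seed so the split is reproducible
    shuffled = list(indices)
    rng.shuffle(shuffled)
    n_test = int(len(shuffled) * test_fraction)
    return shuffled[n_test:], shuffled[:n_test]


val_idx, test_idx = split_in_domain(range(10))
print(len(val_idx), len(test_idx))  # 5 5
```

Keeping the two parts disjoint is what lets the (inD) numbers act as a test accuracy rather than a validation accuracy.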

Reproduce the results of the paper

We provide instructions to reproduce the main results of the paper (Tables 1 and 2). Note that differences in the environment or uncontrolled randomness may produce results slightly different from those reported in the paper.

  • PACS
python train_all.py PACS0 --dataset PACS --deterministic --trial_seed 0 --checkpoint_freq 100 --data_dir /my/datasets/path
python train_all.py PACS1 --dataset PACS --deterministic --trial_seed 1 --checkpoint_freq 100 --data_dir /my/datasets/path
python train_all.py PACS2 --dataset PACS --deterministic --trial_seed 2 --checkpoint_freq 100 --data_dir /my/datasets/path
  • VLCS
python train_all.py VLCS0 --dataset VLCS --deterministic --trial_seed 0 --checkpoint_freq 50 --tolerance_ratio 0.2 --data_dir /my/datasets/path
python train_all.py VLCS1 --dataset VLCS --deterministic --trial_seed 1 --checkpoint_freq 50 --tolerance_ratio 0.2 --data_dir /my/datasets/path
python train_all.py VLCS2 --dataset VLCS --deterministic --trial_seed 2 --checkpoint_freq 50 --tolerance_ratio 0.2 --data_dir /my/datasets/path
  • OfficeHome
python train_all.py OH0 --dataset OfficeHome --deterministic --trial_seed 0 --checkpoint_freq 100 --data_dir /my/datasets/path
python train_all.py OH1 --dataset OfficeHome --deterministic --trial_seed 1 --checkpoint_freq 100 --data_dir /my/datasets/path
python train_all.py OH2 --dataset OfficeHome --deterministic --trial_seed 2 --checkpoint_freq 100 --data_dir /my/datasets/path
  • TerraIncognita
python train_all.py TR0 --dataset TerraIncognita --deterministic --trial_seed 0 --checkpoint_freq 100 --data_dir /my/datasets/path
python train_all.py TR1 --dataset TerraIncognita --deterministic --trial_seed 1 --checkpoint_freq 100 --data_dir /my/datasets/path
python train_all.py TR2 --dataset TerraIncognita --deterministic --trial_seed 2 --checkpoint_freq 100 --data_dir /my/datasets/path
  • DomainNet
python train_all.py DN0 --dataset DomainNet --deterministic --trial_seed 0 --checkpoint_freq 500 --data_dir /my/datasets/path
python train_all.py DN1 --dataset DomainNet --deterministic --trial_seed 1 --checkpoint_freq 500 --data_dir /my/datasets/path
python train_all.py DN2 --dataset DomainNet --deterministic --trial_seed 2 --checkpoint_freq 500 --data_dir /my/datasets/path
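The 15 commands above follow one pattern: three trial seeds per dataset, with a per-dataset checkpoint frequency and, for VLCS only, a tolerance ratio. A sketch that regenerates them from that pattern, with flags transcribed from the commands above; `CONFIGS`, `DATA_DIR`, and `build_commands` are hypothetical names.

```python
# experiment-name prefix -> (dataset, extra flags), transcribed from the commands above.
CONFIGS = {
    "PACS": ("PACS", ["--checkpoint_freq", "100"]),
    "VLCS": ("VLCS", ["--checkpoint_freq", "50", "--tolerance_ratio", "0.2"]),
    "OH": ("OfficeHome", ["--checkpoint_freq", "100"]),
    "TR": ("TerraIncognita", ["--checkpoint_freq", "100"]),
    "DN": ("DomainNet", ["--checkpoint_freq", "500"]),
}
DATA_DIR = "/my/datasets/path"  # replace with your dataset directory


def build_commands(seeds=(0, 1, 2)):
    """Return one argument list per (dataset, trial seed) run."""
    commands = []
    for prefix, (dataset, extra) in CONFIGS.items():
        for seed in seeds:
            commands.append([
                "python", "train_all.py", f"{prefix}{seed}",
                "--dataset", dataset,
                "--deterministic", "--trial_seed", str(seed),
                *extra,
                "--data_dir", DATA_DIR,
            ])
    return commands


if __name__ == "__main__":
    for cmd in build_commands():
        print(" ".join(cmd))
```

Each argument list can also be passed to `subprocess.run(cmd, check=True)` to launch the runs one after another.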

Main Results

Citation

The paper will be published at NeurIPS 2021.

@inproceedings{cha2021swad,
  title={SWAD: Domain Generalization by Seeking Flat Minima},
  author={Cha, Junbum and Chun, Sanghyuk and Lee, Kyungjae and Cho, Han-Cheol and Park, Seunghyun and Lee, Yunsung and Park, Sungrae},
  booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
  year={2021}
}

License

This source code is released under the MIT license, included here.

This project includes some code from DomainBed, also MIT licensed.
