PyTorch code of my WACV 2022 paper Improving Model Generalization by Agreement of Learned Representations from Data Augmentation

Overview

Improving Model Generalization by Agreement of Learned Representations from Data Augmentation (WACV 2022)

Paper

ArXiv

Why does it matter?

When data augmentation is applied to an input image, the model is forced to learn features that are invariant to the augmentation, which improves model generalization (Figure 1).

Since data augmentation incurs little overhead, why not generate two augmented images (also known as two positive samples) from each input? The model is then forced to agree on the common invariant features that support the correct label (Figure 2). It turns out that maximizing this agreement further improves model generalization. We call our method AgMax.
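To make the objective concrete, here is a minimal sketch in plain Python (so it runs without PyTorch). It assumes that agreement between the two views is measured as the mutual information between their predicted class distributions, estimated IIC-style from a batch, and that the total loss is the mean cross-entropy of both views minus a weighted agreement term. The names `two_views`, `mutual_information`, `agmax_style_loss` and the `weight` parameter are hypothetical and are not this repository's API; the exact formulation used in training lives in the repo's own loss code.

```python
import math
import random
from typing import Callable, List, Sequence, Tuple

EPS = 1e-12  # guards against log(0)


def two_views(x, augment: Callable, rng: random.Random) -> Tuple:
    """Return two independently augmented copies (positive samples) of one input."""
    return augment(x, rng), augment(x, rng)


def softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def mutual_information(p1: List[List[float]], p2: List[List[float]]) -> float:
    """IIC-style estimate of I(z1; z2) from per-sample class probabilities of two views."""
    n, k = len(p1), len(p1[0])
    # Joint distribution over (class of view 1, class of view 2), averaged over the batch.
    joint = [[0.0] * k for _ in range(k)]
    for a, b in zip(p1, p2):
        for i in range(k):
            for j in range(k):
                joint[i][j] += a[i] * b[j] / n
    # Symmetrize so that both views play the same role.
    joint = [[(joint[i][j] + joint[j][i]) / 2 for j in range(k)] for i in range(k)]
    # After symmetrizing, row and column marginals are identical.
    marginal = [sum(row) for row in joint]
    mi = 0.0
    for i in range(k):
        for j in range(k):
            if joint[i][j] > 0:
                mi += joint[i][j] * (math.log(joint[i][j] + EPS)
                                     - math.log(marginal[i] + EPS)
                                     - math.log(marginal[j] + EPS))
    return mi


def agmax_style_loss(logits1, logits2, labels, weight: float = 1.0) -> float:
    """Mean cross-entropy of both views minus `weight` times their mutual information."""
    p1 = [softmax(z) for z in logits1]
    p2 = [softmax(z) for z in logits2]
    ce = sum(-math.log(a[y] + EPS) - math.log(b[y] + EPS)
             for a, b, y in zip(p1, p2, labels)) / (2 * len(labels))
    # Subtracting the MI term means that minimizing the loss maximizes agreement.
    return ce - weight * mutual_information(p1, p2)
```

Two confident views that agree on the correct labels get a lower loss than views that disagree. In an actual PyTorch training loop the same computation would run on batched tensors; the plain-Python version only illustrates the shape of the objective.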

Unlike label smoothing, AgMax consistently improves model accuracy. For example, ResNet50 trained on ImageNet1k for 90 epochs reaches the following top-1 accuracy (%):

| Data Augmentation | Baseline | Label Smoothing | AgMax (Ours) |
|---|---|---|---|
| Standard | 76.4 | 76.8 | 76.9 |
| CutOut | 76.2 | 76.5 | 77.1 |
| MixUp | 76.5 | 76.7 | 77.6 |
| CutMix | 76.3 | 76.4 | 77.4 |
| AutoAugment (AA) | 76.2 | 76.2 | 77.1 |
| CutOut+AA | 75.7 | 75.7 | 76.6 |
| MixUp+AA | 75.9 | 76.5 | 77.1 |
| CutMix+AA | 75.5 | 75.5 | 77.0 |
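The "consistently improves" claim can be checked directly against the numbers in the table. The short script below copies the ResNet50 results and computes, for each augmentation, the gain of label smoothing and of AgMax over the baseline; the names `RESULTS` and `gains` are illustrative only.

```python
# ResNet50 accuracy (%) on ImageNet1k, 90 epochs, copied from the table above:
# (baseline, label smoothing, AgMax)
RESULTS = {
    "Standard": (76.4, 76.8, 76.9),
    "CutOut": (76.2, 76.5, 77.1),
    "MixUp": (76.5, 76.7, 77.6),
    "CutMix": (76.3, 76.4, 77.4),
    "AutoAugment (AA)": (76.2, 76.2, 77.1),
    "CutOut+AA": (75.7, 75.7, 76.6),
    "MixUp+AA": (75.9, 76.5, 77.1),
    "CutMix+AA": (75.5, 75.5, 77.0),
}


def gains(results):
    """Return {augmentation: (label-smoothing gain, AgMax gain)} over the baseline."""
    return {
        name: (round(ls - base, 1), round(agmax - base, 1))
        for name, (base, ls, agmax) in results.items()
    }


if __name__ == "__main__":
    g = gains(RESULTS)
    for name, (ls_gain, agmax_gain) in g.items():
        print(f"{name:<18} LS {ls_gain:+.1f}  AgMax {agmax_gain:+.1f}")
    mean_agmax = sum(a for _, a in g.values()) / len(g)
    print(f"mean AgMax gain: {mean_agmax:.2f}")
```

AgMax gains between 0.5 and 1.5 points in every row and beats label smoothing in every row, while label smoothing gives no gain at all for AutoAugment, CutOut+AA and CutMix+AA.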

The figure below demonstrates consistent improvement across different data augmentation methods:

Install requirements

pip3 install -r requirements.txt

Train

For example, to train ResNet50 with AgMax on 2 GPUs for 90 epochs using SGD with lr=0.1 and a multistep learning rate scheduler:

CUDA_VISIBLE_DEVICES=0,1 python3 main.py --config=ResNet50-standard-agmax --train \
--multisteplr --dataset=imagenet --epochs=90 --save
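The `--multisteplr` option selects a multistep learning rate schedule. As a minimal sketch, assuming the learning rate starts at 0.1 and is multiplied by 0.1 at epochs 30 and 60 (a common ImageNet recipe; the milestones and decay factor actually used by `main.py` are not stated here), the per-epoch learning rate works out as below. The function `multistep_lr` is hypothetical; in PyTorch the same behavior comes from `torch.optim.lr_scheduler.MultiStepLR` stepped once per epoch.

```python
from typing import Sequence


def multistep_lr(epoch: int, base_lr: float = 0.1,
                 milestones: Sequence[int] = (30, 60), gamma: float = 0.1) -> float:
    """Learning rate at `epoch` (0-based) under a multistep schedule.

    The rate is multiplied by `gamma` once for every milestone already reached,
    which matches MultiStepLR when scheduler.step() is called once per epoch.
    The default milestones are an assumption, not the repo's configuration.
    """
    drops = sum(1 for m in milestones if epoch >= m)
    return base_lr * gamma ** drops


if __name__ == "__main__":
    for e in (0, 29, 30, 59, 60, 89):
        print(e, multistep_lr(e))
```

For example, `multistep_lr(29)` is still 0.1, while `multistep_lr(30)` is 0.01 up to floating-point rounding.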

For comparison, train the same model without AgMax:

CUDA_VISIBLE_DEVICES=0,1 python3 main.py --config=ResNet50-standard --train \
--multisteplr --dataset=imagenet --epochs=90 --save

Test

Using a pre-trained model:

ResNet101 trained with CutMix, AutoAugment and AgMax:

mkdir checkpoints
cd checkpoints
wget https://github.com/roatienza/agmax/releases/download/agmax-0.1.0/imagenet-agmax-mi-ResNet101-cutmix-auto_augment-81.19-mlp-4096.pth
cd ..
python3 main.py --config=ResNet101-auto_augment-cutmix-agmax --eval \
--dataset=imagenet \
--resume imagenet-agmax-mi-ResNet101-cutmix-auto_augment-81.19-mlp-4096.pth

ResNet50 trained with CutMix, AutoAugment and AgMax:

python3 main.py --config=ResNet50-auto_augment-cutmix-agmax --eval --n-units=2048 \
--dataset=imagenet --resume imagenet-agmax-ResNet50-cutmix-auto_augment-79.12-mlp-2048.pth
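The checkpoint filenames appear to end with the reported accuracy and the width of the MLP head: the ResNet50 file ends in `79.12-mlp-2048.pth`, and its command passes `--n-units=2048`. This naming convention is inferred from the two filenames above rather than documented, so treat the hypothetical helper below as a convenience sketch for reading those values back (for example, to choose a matching `--n-units` before evaluating).

```python
import re
from typing import NamedTuple, Optional


class CheckpointInfo(NamedTuple):
    accuracy: float  # number before "-mlp-" (appears to be the reported accuracy)
    mlp_units: int   # number after "-mlp-" (appears to match --n-units)


# Assumed pattern: "...-<accuracy>-mlp-<units>.pth"
_PATTERN = re.compile(r"-(\d+\.\d+)-mlp-(\d+)\.pth$")


def parse_checkpoint_name(filename: str) -> Optional[CheckpointInfo]:
    """Extract accuracy and MLP width from a checkpoint name, or None if it doesn't match."""
    match = _PATTERN.search(filename)
    if match is None:
        return None
    return CheckpointInfo(float(match.group(1)), int(match.group(2)))


if __name__ == "__main__":
    name = "imagenet-agmax-ResNet50-cutmix-auto_augment-79.12-mlp-2048.pth"
    info = parse_checkpoint_name(name)
    print(f"--n-units={info.mlp_units}  (reported accuracy {info.accuracy})")
```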

Other pre-trained models (Baselines):

Citation

If you find this work useful, please cite:

@inproceedings{atienza2022agmax,
  title={Improving Model Generalization by Agreement of Learned Representations from Data Augmentation},
  author={Atienza, Rowel},
  booktitle={IEEE/CVF Winter Conference on Applications of Computer Vision},
  year={2022},
  pubstate={published},
  tppubtype={inproceedings}
}