A PyTorch implementation of Sharpness-Aware Minimization for Efficiently Improving Generalization

Overview

sam.pytorch

A PyTorch implementation of Sharpness-Aware Minimization for Efficiently Improving Generalization ( Foret+2020) Paper, Official implementation .

Requirements

  • Python>=3.8
  • PyTorch>=1.7.1

To run the example, you further need

  • homura by pip install -U homura-core==2020.12.0
  • chika by pip install -U chika

Example

python cifar10.py [--optim.name {sam,sgd}] [--model {renst20, wrn28_2}] [--optim.rho 0.05]

Results: Test Accuracy (CIFAR-10)

Model SAM SGD
ResNet-20 93.5 93.2
WRN28-2 95.8 95.4
ResNeXT29 96.4 95.8

SAM needs double forward passes per each update, thus training with SAM is slower than training with SGD. In case of ResNet-20 training, 80 mins vs 50 mins on my environment. Additional options --use_amp --jit_model may slightly accelerates the training.

Usage

SAMSGD can be used as a drop-in replacement of PyTorch optimizers with closures. Also, it is compatible with lr_scheduler and has state_dict and load_state_dict.

from sam import SAMSGD

optimizer = SAMSGD(model.parameters(), lr=1e-1, rho=0.05)

for input, target in dataset:
    def closure():
        optimizer.zero_grad()
        output = model(input)
        loss = loss_f(output, target)
        loss.backward()
        return loss


    loss = optimizer.step(closure)

Citation

@ARTICLE{2020arXiv201001412F,
    author = {{Foret}, Pierre and {Kleiner}, Ariel and {Mobahi}, Hossein and {Neyshabur}, Behnam},
    title = "{Sharpness-Aware Minimization for Efficiently Improving Generalization}",
    year = 2020,
    eid = {arXiv:2010.01412},
    eprint = {2010.01412},
}

@software{sampytorch
    author = {Ryuichiro Hataya},
    titile = {sam.pytorch},
    url    = {https://github.com/moskomule/sam.pytorch},
    year   = {2020}
}
Owner
Ryuichiro Hataya
PhD student at UTokyo and RA at RIKEN AIP / focusing on DL and ML
Ryuichiro Hataya
Official implementation of our paper "LLA: Loss-aware Label Assignment for Dense Pedestrian Detection" in Pytorch.

LLA: Loss-aware Label Assignment for Dense Pedestrian Detection This project provides an implementation for "LLA: Loss-aware Label Assignment for Dens

35 Dec 06, 2022
This repository contains the implementation of Deep Detail Enhancment for Any Garment proposed in Eurographics 2021

Deep-Detail-Enhancement-for-Any-Garment Introduction This repository contains the implementation of Deep Detail Enhancment for Any Garment proposed in

40 Dec 13, 2022
Certifiable Outlier-Robust Geometric Perception

Certifiable Outlier-Robust Geometric Perception About This repository holds the implementation for certifiably solving outlier-robust geometric percep

83 Dec 31, 2022
Adaptation through prediction: multisensory active inference torque control

Adaptation through prediction: multisensory active inference torque control Submitted to IEEE Transactions on Cognitive and Developmental Systems Abst

Cristian Meo 1 Nov 07, 2022
learned_optimization: Training and evaluating learned optimizers in JAX

learned_optimization: Training and evaluating learned optimizers in JAX learned_optimization is a research codebase for training learned optimizers. I

Google 533 Dec 30, 2022
Data augmentation for NLP, accepted at EMNLP 2021 Findings

AEDA: An Easier Data Augmentation Technique for Text Classification This is the code for the EMNLP 2021 paper AEDA: An Easier Data Augmentation Techni

Akbar Karimi 81 Dec 09, 2022
Machine Learning Time-Series Platform

cesium: Open-Source Platform for Time Series Inference Summary cesium is an open source library that allows users to: extract features from raw time s

632 Dec 26, 2022
Pydantic models for pywttr and aiopywttr.

Pydantic models for pywttr and aiopywttr.

Almaz 2 Dec 08, 2022
ICRA 2021 - Robust Place Recognition using an Imaging Lidar

Robust Place Recognition using an Imaging Lidar A place recognition package using high-resolution imaging lidar. For best performance, a lidar equippe

Tixiao Shan 293 Dec 27, 2022
Official Implementation of CoSMo: Content-Style Modulation for Image Retrieval with Text Feedback

CoSMo.pytorch Official Implementation of CoSMo: Content-Style Modulation for Image Retrieval with Text Feedback, Seungmin Lee*, Dongwan Kim*, Bohyung

Seung Min Lee 54 Dec 08, 2022
Official Implementation of DE-CondDETR and DELA-CondDETR in "Towards Data-Efficient Detection Transformers"

DE-DETRs By Wen Wang, Jing Zhang, Yang Cao, Yongliang Shen, and Dacheng Tao This repository is an official implementation of DE-CondDETR and DELA-Cond

Wen Wang 41 Dec 12, 2022
Deep learning library for solving differential equations and more

DeepXDE Voting on whether we should have a Slack channel for discussion. DeepXDE is a library for scientific machine learning. Use DeepXDE if you need

Lu Lu 1.4k Dec 29, 2022
Ganilla - Official Pytorch implementation of GANILLA

GANILLA We provide PyTorch implementation for: GANILLA: Generative Adversarial Networks for Image to Illustration Translation. Paper Arxiv Updates (Fe

Samet Hi 462 Dec 05, 2022
Finetune the base 64 px GLIDE-text2im model from OpenAI on your own image-text dataset

Finetune the base 64 px GLIDE-text2im model from OpenAI on your own image-text dataset

Clay Mullis 82 Oct 13, 2022
Out-of-Domain Human Mesh Reconstruction via Dynamic Bilevel Online Adaptation

DynaBOA Code repositoty for the paper: Out-of-Domain Human Mesh Reconstruction via Dynamic Bilevel Online Adaptation Shanyan Guan, Jingwei Xu, Michell

197 Jan 07, 2023
Cross-Modal Contrastive Learning for Text-to-Image Generation

Cross-Modal Contrastive Learning for Text-to-Image Generation This repository hosts the open source JAX implementation of XMC-GAN. Setup instructions

Google Research 94 Nov 12, 2022
NAACL2021 - COIL Contextualized Lexical Retriever

COIL Repo for our NAACL paper, COIL: Revisit Exact Lexical Match in Information Retrieval with Contextualized Inverted List. The code covers learning

Luyu Gao 108 Dec 31, 2022
[TPAMI 2021] iOD: Incremental Object Detection via Meta-Learning

Incremental Object Detection via Meta-Learning To appear in an upcoming issue of the IEEE Transactions on Pattern Analysis and Machine Intelligence (T

Joseph K J 66 Jan 04, 2023
A Structured Self-attentive Sentence Embedding

Structured Self-attentive sentence embeddings Implementation for the paper A Structured Self-Attentive Sentence Embedding, which was published in ICLR

Kaushal Shetty 488 Nov 28, 2022
An example project demonstrating how the Autonomous Learning Library can be used to build new reinforcement learning agents.

About This repository shows how Autonomous Learning Library can be used to build new reinforcement learning agents. In particular, it contains a model

Chris Nota 5 Aug 30, 2022