A state-of-the-art semi-supervised method for image recognition

Overview

Mean teachers are better role models

Paper | NIPS 2017 poster | NIPS 2017 spotlight slides | Blog post

By Antti Tarvainen, Harri Valpola (The Curious AI Company)

Approach

Mean Teacher is a simple method for semi-supervised learning. It consists of the following steps:

  1. Take a supervised architecture and make a copy of it. Let's call the original model the student and the new one the teacher.
  2. At each training step, use the same minibatch as inputs to both the student and the teacher but add random augmentation or noise to the inputs separately.
  3. Add a consistency cost between the student and teacher outputs (after softmax).
  4. Let the optimizer update the student weights normally.
  5. Let the teacher weights be an exponential moving average (EMA) of the student weights. That is, after each training step, update the teacher weights a little bit toward the student weights.
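Step 5 is the part that distinguishes Mean Teacher from earlier approaches. The following is a minimal sketch of that update in plain Python. It assumes each named parameter is stored as a flat list of floats, which is a hypothetical stand-in for framework tensors. The early warm-up of the decay is also an assumption of this sketch and is not required by the steps above. In a real framework you would apply the same update, without gradients, to every teacher parameter after each optimizer step on the student.

```python
from typing import Dict, List

# Hypothetical representation: each named parameter is a flat list of floats.
Params = Dict[str, List[float]]


def update_teacher(teacher: Params, student: Params, ema_decay: float, step: int) -> None:
    """Move the teacher weights a little toward the student weights, in place.

    teacher <- alpha * teacher + (1 - alpha) * student
    """
    # Assumed warm-up: while step / (step + 1) < ema_decay, the teacher equals the
    # plain mean of all student weights seen so far; afterwards alpha = ema_decay.
    alpha = min(1.0 - 1.0 / (step + 1), ema_decay)
    for name, student_values in student.items():
        teacher_values = teacher[name]
        for i, s in enumerate(student_values):
            teacher_values[i] = alpha * teacher_values[i] + (1.0 - alpha) * s


# Usage: the assignment to student["w"] stands in for an optimizer step.
student = {"w": [0.0]}
teacher = {"w": [0.0]}
for step, w in enumerate([1.0, 2.0, 3.0]):
    student["w"][0] = w
    update_teacher(teacher, student, ema_decay=0.999, step=step)
print(teacher["w"][0])  # approximately 2.0, the mean of 1.0, 2.0 and 3.0
```

Because the teacher averages weights rather than storing per-example predictions, the update costs one pass over the parameters per step, regardless of dataset size.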

Our contribution is the last step. Laine and Aila [paper] used shared parameters between the student and the teacher, or used a temporal ensemble of teacher predictions. In comparison, Mean Teacher is more accurate. Because it averages weights rather than per-example predictions, it is also applicable to large datasets.

[Figure: the Mean Teacher model]

Mean Teacher works well with modern architectures. Combining Mean Teacher with ResNets, we improved the state of the art in semi-supervised learning on the ImageNet and CIFAR-10 datasets.

| ImageNet, 10% of labels | Top-5 validation error (%) |
|---|---|
| Variational Auto-Encoder [paper] | 35.42 ± 0.90 |
| Mean Teacher ResNet-152 | 9.11 ± 0.12 |
| All labels, state of the art [paper] | 3.79 |

| CIFAR-10, 4000 labels | Test error (%) |
|---|---|
| CT-GAN [paper] | 9.98 ± 0.21 |
| Mean Teacher ResNet-26 | 6.28 ± 0.15 |
| All labels, state of the art [paper] | 2.86 |

Implementation

There are two implementations, one for TensorFlow and one for PyTorch. The PyTorch version is probably easier to adapt to your needs, since it follows typical PyTorch idioms, and there's a natural place to add your model and dataset. Let me know if anything needs clarification.

Regarding the results in the paper, the experiments using a traditional ConvNet architecture were run with the TensorFlow version. The experiments using residual networks were run with the PyTorch version.

Tips for choosing hyperparameters and other tuning

Mean Teacher introduces two new hyperparameters: EMA decay rate and consistency cost weight. The optimal value for each of these depends on the dataset, the model, and the composition of the minibatches. You will also need to choose how to interleave unlabeled samples and labeled samples in minibatches.
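One way to interleave the two kinds of samples is to build every minibatch from a fixed number of labeled indices plus unlabeled ones. The generator below is a hypothetical, framework-free sketch of that idea and is not the TwoStreamBatchSampler code. Its name and parameters are illustrative. One pass over the unlabeled indices defines an epoch, and the usually much smaller labeled set is reshuffled and cycled as often as needed.

```python
import itertools
import random
from typing import Iterator, List, Sequence


def two_stream_batches(
    unlabeled: Sequence[int],
    labeled: Sequence[int],
    batch_size: int,
    labeled_per_batch: int,
    seed: int = 0,
) -> Iterator[List[int]]:
    """Yield minibatches of dataset indices: labeled indices first, then unlabeled."""
    if not 0 < labeled_per_batch < batch_size:
        raise ValueError("need 0 < labeled_per_batch < batch_size")
    if len(labeled) == 0:
        raise ValueError("need at least one labeled index")
    rng = random.Random(seed)
    unlabeled_per_batch = batch_size - labeled_per_batch

    def reshuffled_forever(indices: Sequence[int]) -> Iterator[int]:
        # Cycle through the labeled set, reshuffling on every pass.
        while True:
            order = list(indices)
            rng.shuffle(order)
            yield from order

    unlabeled_order = list(unlabeled)
    rng.shuffle(unlabeled_order)
    labeled_stream = reshuffled_forever(labeled)
    # Drop a final incomplete unlabeled chunk so every batch has the same composition.
    last_start = len(unlabeled_order) - unlabeled_per_batch
    for start in range(0, last_start + 1, unlabeled_per_batch):
        unlabeled_part = unlabeled_order[start:start + unlabeled_per_batch]
        labeled_part = list(itertools.islice(labeled_stream, labeled_per_batch))
        yield labeled_part + unlabeled_part


for batch in two_stream_batches(range(100, 112), range(4), batch_size=8, labeled_per_batch=2):
    print(batch)  # 2 labeled indices followed by 6 unlabeled ones
```

With batch_size=8 and labeled_per_batch=2, a quarter of each minibatch is labeled. Setting labeled_per_batch=4 would make it half.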

Here are some rules of thumb to get you started:

  • If you are working on a new dataset, it may be easiest to start with only labeled data and do pure supervised training. Then, when you are happy with the architecture and hyperparameters, add Mean Teacher. The same network should work well, although you may want to reduce any regularization, such as weight decay, that you used with the small labeled set.
  • Mean Teacher needs some noise in the model to work optimally. In practice, the best noise is probably random input augmentations. Use whatever relevant augmentations you can think of: the algorithm will train the model to be invariant to them.
  • It's useful to dedicate a portion of each minibatch to labeled examples. That way, the supervised training signal is strong enough early on to train quickly and keeps the model from getting stuck in uncertainty. In the PyTorch examples, a quarter or a half of each minibatch holds labeled examples and the rest holds unlabeled ones. (See TwoStreamBatchSampler in the PyTorch code.)
  • For the EMA decay rate, 0.999 seems to be a good starting point.
  • You can use either MSE or KL-divergence as the consistency cost function. For KL-divergence, a good consistency cost weight is often between 1.0 and 10.0. For MSE, it seems to be between the number of classes and the number of classes squared. On small datasets we saw MSE give better results, but KL always worked pretty well too.
  • It may help to ramp up the consistency cost in the beginning over the first few epochs until the teacher network starts giving good predictions.
  • An additional trick we used in the PyTorch examples: have two separate logit layers at the top level. Use one for classifying labeled examples and one for predicting the teacher output, and add a cost between the logits of these two predictions. The intent is the same as with the consistency cost ramp-up. In the beginning the teacher output may be wrong, so this loosens the link between the classification prediction and the consistency cost. (See the --logit-distance-cost argument in the PyTorch implementation.)
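The consistency-cost options, the ramp-up and the two-head logit distance from the tips above can be sketched per example in plain Python. This is an illustration under assumptions, not the PyTorch implementation. All function names are hypothetical. The schedule exp(-5(1 - t)^2) is the sigmoid-shaped ramp-up used by Laine and Aila and is only one reasonable choice. The weights (10.0 for KL, 0.01 for the logit distance) and the 5-epoch ramp-up are placeholders to tune.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)  # shift by the max for numerical stability
    log_total = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_total for x in logits]


def softmax(logits: Sequence[float]) -> List[float]:
    return [math.exp(v) for v in log_softmax(logits)]


def mse_consistency(student_logits: Sequence[float], teacher_logits: Sequence[float]) -> float:
    """Squared error between student and teacher softmax outputs, summed over classes."""
    p, q = softmax(student_logits), softmax(teacher_logits)
    return sum((pi - qi) ** 2 for pi, qi in zip(p, q))


def kl_consistency(student_logits: Sequence[float], teacher_logits: Sequence[float]) -> float:
    """KL(teacher || student): the teacher's probabilities act as a soft target."""
    log_p, log_q = log_softmax(student_logits), log_softmax(teacher_logits)
    return sum(math.exp(lq) * (lq - lp) for lp, lq in zip(log_p, log_q))


def sigmoid_rampup(epoch: float, rampup_epochs: float) -> float:
    """Rises smoothly from exp(-5) at epoch 0 to 1.0 at rampup_epochs, then stays at 1.0."""
    if rampup_epochs <= 0:
        return 1.0
    t = min(max(epoch, 0.0), rampup_epochs) / rampup_epochs
    return math.exp(-5.0 * (1.0 - t) ** 2)


def consistency_weight(epoch: float, max_weight: float, rampup_epochs: float) -> float:
    return max_weight * sigmoid_rampup(epoch, rampup_epochs)


def logit_distance(class_logits: Sequence[float], cons_logits: Sequence[float]) -> float:
    """Mean squared difference between the two student heads' logits."""
    return sum((a - b) ** 2 for a, b in zip(class_logits, cons_logits)) / len(class_logits)


def cross_entropy(logits: Sequence[float], label: int) -> float:
    return -log_softmax(logits)[label]


# Hypothetical outputs for one labeled example with 3 classes.
class_head = [2.0, 0.5, -1.0]    # student head used for classification
cons_head = [1.5, 0.7, -0.8]     # student head compared with the teacher
teacher_out = [1.8, 0.4, -1.2]   # teacher logits (treated as constants)
epoch = 2

loss = (
    cross_entropy(class_head, label=0)
    + consistency_weight(epoch, max_weight=10.0, rampup_epochs=5) * kl_consistency(cons_head, teacher_out)
    + 0.01 * logit_distance(class_head, cons_head)
)
print(f"total loss: {loss:.4f}")
```

For unlabeled examples, drop the cross-entropy term and keep the other two. This sketch sums the squared error over classes. Averaging over classes instead divides the MSE cost by the number of classes, so a suitable weight would grow by the same factor.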