A state-of-the-art semi-supervised method for image recognition

Overview

Mean teachers are better role models

Paper | NIPS 2017 poster | NIPS 2017 spotlight slides | Blog post

By Antti Tarvainen, Harri Valpola (The Curious AI Company)

Approach

Mean Teacher is a simple method for semi-supervised learning. It consists of the following steps:

  1. Take a supervised architecture and make a copy of it. Let's call the original model the student and the new one the teacher.
  2. At each training step, use the same minibatch as inputs to both the student and the teacher but add random augmentation or noise to the inputs separately.
  3. Add a consistency cost between the student and teacher outputs (after softmax) on top of the usual classification cost.
  4. Let the optimizer update the student weights normally.
  5. Let the teacher weights be an exponential moving average (EMA) of the student weights. That is, after each training step, update the teacher weights a little bit toward the student weights.
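Step 5, the exponential moving average, can be sketched as follows. This is a minimal illustration, not the repository's code: it assumes (hypothetically) that each model's weights are stored as a dict mapping parameter names to flat lists of floats, and the name `ema_update` is our own. In PyTorch or TensorFlow you would apply the same update to every parameter tensor of the teacher, without backpropagating into the teacher.

```python
from typing import Dict, List

# Hypothetical layout: parameter name -> flat list of values.
Weights = Dict[str, List[float]]


def ema_update(teacher: Weights, student: Weights, alpha: float = 0.999) -> None:
    """Move every teacher weight a small step toward the matching student weight.

    teacher_w <- alpha * teacher_w + (1 - alpha) * student_w
    A larger alpha (the EMA decay rate) makes the teacher change more slowly.
    The teacher is modified in place; the student is left untouched.
    """
    for name, teacher_w in teacher.items():
        student_w = student[name]
        for i, (t, s) in enumerate(zip(teacher_w, student_w)):
            teacher_w[i] = alpha * t + (1.0 - alpha) * s


# Usage: call once after every optimizer step on the student.
teacher = {"w": [0.0, 0.0]}
student = {"w": [1.0, 2.0]}
ema_update(teacher, student, alpha=0.9)
print(teacher["w"])  # approximately [0.1, 0.2]
```

With alpha = 0.999, each update moves the teacher 0.1% of the way toward the student, so the teacher roughly averages the student weights over the last 1/(1 - alpha) = 1000 training steps.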

Our contribution is the last step. Laine and Aila [paper] used shared parameters between the student and the teacher, or used a temporal ensemble of teacher predictions. In comparison, Mean Teacher is more accurate and applicable to large datasets.

[Figure: the Mean Teacher model]

Mean Teacher works well with modern architectures. Combining Mean Teacher with ResNets, we improved the state of the art in semi-supervised learning on the ImageNet and CIFAR-10 datasets.

ImageNet using 10% of the labels          Top-5 validation error (%)
  Variational Auto-Encoder [paper]        35.42 ± 0.90
  Mean Teacher ResNet-152                  9.11 ± 0.12
  All labels, state of the art [paper]     3.79

CIFAR-10 using 4000 labels                Test error (%)
  CT-GAN [paper]                           9.98 ± 0.21
  Mean Teacher ResNet-26                   6.28 ± 0.15
  All labels, state of the art [paper]     2.86

Implementation

There are two implementations, one for TensorFlow and one for PyTorch. The PyTorch version is probably easier to adapt to your needs, since it follows typical PyTorch idioms, and there's a natural place to add your model and dataset. Let me know if anything needs clarification.

Regarding the results in the paper, the experiments using a traditional ConvNet architecture were run with the TensorFlow version. The experiments using residual networks were run with the PyTorch version.

Tips for choosing hyperparameters and other tuning

Mean Teacher introduces two new hyperparameters: EMA decay rate and consistency cost weight. The optimal value for each of these depends on the dataset, the model, and the composition of the minibatches. You will also need to choose how to interleave unlabeled samples and labeled samples in minibatches.
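One way to interleave the two kinds of samples is to reserve a fixed number of slots in every minibatch for labeled examples. The sketch below illustrates that idea in pure Python; it is not the repository's TwoStreamBatchSampler, and the function name `two_stream_batches` and its parameters are hypothetical. It only produces lists of dataset indices, which you would feed to your own data loader.

```python
import random
from typing import Iterator, List, Sequence


def two_stream_batches(labeled: Sequence[int], unlabeled: Sequence[int],
                       batch_size: int, labeled_per_batch: int,
                       seed: int = 0) -> Iterator[List[int]]:
    """Yield index batches that each start with labeled_per_batch labeled indices.

    One pass over the shuffled unlabeled indices defines an epoch. The labeled
    indices are reshuffled and reused as often as needed, since there are
    usually far fewer of them. Leftover unlabeled indices that cannot fill a
    whole batch are dropped.
    """
    if not 0 < labeled_per_batch < batch_size:
        raise ValueError("labeled_per_batch must be between 1 and batch_size - 1")
    if not labeled:
        raise ValueError("need at least one labeled index")
    rng = random.Random(seed)
    unlabeled_per_batch = batch_size - labeled_per_batch

    def cycle_labeled() -> Iterator[int]:
        pool = list(labeled)
        while True:
            rng.shuffle(pool)  # new order on every pass over the labeled set
            yield from pool

    labeled_stream = cycle_labeled()
    order = list(unlabeled)
    rng.shuffle(order)
    last_start = len(order) - unlabeled_per_batch
    for start in range(0, last_start + 1, unlabeled_per_batch):
        batch = [next(labeled_stream) for _ in range(labeled_per_batch)]
        batch.extend(order[start:start + unlabeled_per_batch])
        yield batch


# Usage: 4 labeled and 96 unlabeled examples, a quarter of each batch labeled.
batches = list(two_stream_batches(list(range(4)), list(range(4, 100)),
                                  batch_size=8, labeled_per_batch=2))
print(len(batches))  # 16: the 96 unlabeled indices fill 6 slots per batch
```

Keeping the labeled examples at the front of each batch makes it easy to apply the classification cost to the first `labeled_per_batch` items and the consistency cost to all of them.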

Here are some rules of thumb to get you started:

  • If you are working on a new dataset, it may be easiest to start with only labeled data and do pure supervised training. Then, when you are happy with the architecture and hyperparameters, add Mean Teacher. The same network should work well, although you may want to tune down regularization, such as weight decay, that you used when training on the small labeled set alone.
  • Mean Teacher needs some noise in the model to work optimally. In practice, the best noise is probably random input augmentations. Use whatever relevant augmentations you can think of: the algorithm will train the model to be invariant to them.
  • It's useful to dedicate a portion of each minibatch to labeled examples. Then the supervised training signal is strong enough early on to train quickly and keep the model from getting stuck in uncertainty. In the PyTorch examples, a quarter or a half of each minibatch holds labeled examples and the rest holds unlabeled ones. (See TwoStreamBatchSampler in the PyTorch code.)
  • For the EMA decay rate, 0.999 seems to be a good starting point.
  • You can use either MSE or KL-divergence as the consistency cost function. For KL-divergence, a good consistency cost weight is often between 1.0 and 10.0. For MSE, it seems to lie between the number of classes and the number of classes squared. On small datasets we saw MSE give better results, but KL always worked pretty well too.
  • It may help to ramp up the consistency cost in the beginning over the first few epochs until the teacher network starts giving good predictions.
  • An additional trick we used in the PyTorch examples: have two separate logit layers at the top level. Use one for classifying labeled examples and one for predicting the teacher output, and add a cost on the distance between the logits of these two predictions. The intent is the same as with the consistency cost ramp-up: in the beginning the teacher output may be wrong, so loosen the link between the classification prediction and the consistency cost. (See the --logit-distance-cost argument in the PyTorch implementation.)
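The consistency costs, the ramp-up, and the two-head logit distance from the tips above can be sketched for a single example as follows. This is a pure-Python illustration under our own assumptions, not the PyTorch implementation: all function names are hypothetical, the ramp-up uses the Gaussian-shaped schedule exp(-5(1 - t)^2) as one reasonable choice, the logit distance is assumed to be a mean squared difference between the two heads' logits, and its 0.01 weight in the usage lines is an arbitrary placeholder.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    log_total = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_total for x in logits]


def softmax(logits: Sequence[float]) -> List[float]:
    return [math.exp(v) for v in log_softmax(logits)]


def mse_consistency(student_logits, teacher_logits) -> float:
    """Squared difference between the two softmax outputs, summed over classes."""
    p_student, p_teacher = softmax(student_logits), softmax(teacher_logits)
    return sum((s - t) ** 2 for s, t in zip(p_student, p_teacher))


def kl_consistency(student_logits, teacher_logits) -> float:
    """KL(teacher || student): the teacher's softmax output acts as the target."""
    ls, lt = log_softmax(student_logits), log_softmax(teacher_logits)
    return sum(math.exp(t) * (t - s) for s, t in zip(ls, lt))


def sigmoid_rampup(epoch: float, rampup_length: float) -> float:
    """Rises smoothly from exp(-5) at epoch 0 to 1.0 at rampup_length (assumed schedule)."""
    if rampup_length <= 0:
        return 1.0
    t = min(max(epoch / rampup_length, 0.0), 1.0)
    return math.exp(-5.0 * (1.0 - t) ** 2)


def consistency_weight(epoch: float, max_weight: float, rampup_length: float) -> float:
    return max_weight * sigmoid_rampup(epoch, rampup_length)


def logit_distance(class_logits, consistency_logits) -> float:
    """Mean squared difference between the two heads' logits (assumed form)."""
    diffs = [(a - b) ** 2 for a, b in zip(class_logits, consistency_logits)]
    return sum(diffs) / len(diffs)


# Usage for one unlabeled example with 3 classes.
class_logits = [2.0, 0.5, -1.0]    # student head 1: trained on labels
cons_logits = [1.8, 0.6, -0.9]     # student head 2: matched to the teacher
teacher_logits = [1.5, 0.7, -0.5]  # teacher output, treated as a constant
w = consistency_weight(epoch=2, max_weight=5.0, rampup_length=5)
unlabeled_cost = (w * mse_consistency(cons_logits, teacher_logits)
                  + 0.01 * logit_distance(class_logits, cons_logits))
```

The max weight of 5.0 sits inside the MSE range suggested above for 3 classes (between 3 and 9). In a framework, the teacher's logits would come from a forward pass with gradients disabled, since the teacher is updated only through the moving average.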