Minimal implementation of PAWS (https://arxiv.org/abs/2104.13963) in TensorFlow.

Overview

PAWS-TF 🐾

Implementation of Semi-Supervised Learning of Visual Features by Non-Parametrically Predicting View Assignments with Support Samples (PAWS) in TensorFlow (2.4.1).

PAWS introduces a simple way to combine a very small fraction of labeled data with a comparatively larger corpus of unlabeled data during pre-training. With this approach, it set the state of the art in semi-supervised learning (as of May 2021), beating methods such as SimCLRv2 and Meta Pseudo Labels while using fewer parameters and a shorter pre-training schedule. For details, I recommend checking out the original paper as well as this blog post by the authors.

This repository implements all the major components proposed in PAWS in TensorFlow. The only major difference is that pre-training and the subsequent fine-tuning were not run for the original number of epochs (600 and 30, respectively) to save compute. I have reused the utility components for the PAWS loss from the original implementation.

Dataset ⌗

The current code works with CIFAR10 and uses 4000 labeled samples (8% of the 50,000 training images) during pre-training, along with the unlabeled samples.
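With a balanced split, 4000 labels means 400 images per class (4000 / 50,000 = 8%). The sketch below draws such a class-balanced labeled subset in plain Python. It assumes a balanced 400-per-class split. `stratified_subset` is a hypothetical helper, not the repository's code; the indices actually used are the ones published with the repository.

```python
import random
from collections import defaultdict


def stratified_subset(labels, per_class, seed=0):
    """Hypothetical helper: return sorted indices with exactly `per_class` examples of every class.

    `labels` is a sequence of integer class ids, one per training example.
    """
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    chosen = []
    for label in sorted(by_class):
        indices = by_class[label]
        if len(indices) < per_class:
            raise ValueError(f"class {label} has only {len(indices)} examples")
        # Sample without replacement within each class.
        chosen.extend(rng.sample(indices, per_class))
    return sorted(chosen)


# Toy stand-in for the CIFAR10 training labels: 10 classes x 5000 images each.
toy_labels = [i % 10 for i in range(50_000)]
subset = stratified_subset(toy_labels, per_class=400)
print(len(subset), len(subset) / len(toy_labels))  # 4000 0.08
```

Fixing the seed makes the labeled split reproducible. That is the same reason the repository publishes its sampled indices.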

Features

  • Multi-crop augmentation strategy (originally introduced in SwAV)
  • Class stratified sampler (common in few-shot classification problems)
  • WarmUpCosine learning rate schedule (which is typical for self-supervised and semi-supervised pre-training)
  • LARS optimizer (comes from TensorFlow Model Garden)
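A WarmUpCosine schedule can be sketched as a linear warm-up followed by cosine decay. This is a generic formulation under stated assumptions. The function name and its arguments are hypothetical, and the repository's WarmUpCosine class may differ in details such as the starting or final learning rate.

```python
import math


def warmup_cosine(step, total_steps, warmup_steps, base_lr, start_lr=0.0):
    """Hypothetical schedule: linear warm-up from start_lr to base_lr, then cosine decay to 0."""
    if step < warmup_steps:
        # Linear ramp during the warm-up phase.
        return start_lr + (base_lr - start_lr) * step / warmup_steps
    # Fraction of the post-warm-up phase already completed, clipped to [0, 1].
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))


for s in (0, 50, 100, 550, 1000):
    print(s, round(warmup_cosine(s, total_steps=1000, warmup_steps=100, base_lr=1.0), 4))
```

The warm-up avoids large, noisy updates while the randomly initialized network (and LARS' layer-wise scaling) settles. After that, the cosine decay gradually anneals the learning rate toward zero.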

The trunk portion of a WideResNet-28-2 (all layers except the final classification layer) is used inside the encoder for CIFAR10. All experimental configurations follow Appendix C of the paper.

Setup and code structure 💻

A GCP VM (n1-standard-8) with a single V100 GPU was used for executing the code.

  • paws_train.py runs the pre-training as introduced in PAWS.
  • fine_tune.py runs the fine-tuning part as suggested in Appendix C. Note that this is only required for CIFAR10.
  • nn_eval.py runs the soft nearest neighbor classification on the CIFAR10 test set.

Pre-training and fine-tuning take 1.4 hours in total. All the logs are available in misc/logs.txt. Additionally, the indices that were used to sample the labeled examples from the CIFAR10 training set are available here.

Results 📊

Pre-training

During pre-training, PAWS minimizes a cross-entropy loss while also maximizing the entropy of the mean prediction (mean-entropy maximization). The training plot reflects this as well.
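The two terms can be illustrated with a small pure-Python sketch. The first is the cross-entropy between each prediction and a sharpened target; the second subtracts the entropy of the mean sharpened prediction. This is neither the repository's code nor the original implementation. The helper names, the sharpening temperature `T=0.25`, and averaging sharpened rather than raw predictions in the regularizer are all assumptions. For simplicity the batch serves as its own targets. Plain Python has no gradients, so the stop-gradient on the targets is not modeled.

```python
import math


def sharpen(p, T=0.25):
    """Raise each probability to the power 1/T and renormalize (T < 1 makes p peakier)."""
    powered = [x ** (1.0 / T) for x in p]
    total = sum(powered)
    return [x / total for x in powered]


def cross_entropy(target, pred, eps=1e-12):
    """H(target, pred) = -sum_k target_k * log(pred_k)."""
    return -sum(t * math.log(q + eps) for t, q in zip(target, pred))


def entropy(p, eps=1e-12):
    """H(p) = -sum_k p_k * log(p_k)."""
    return -sum(x * math.log(x + eps) for x in p)


def paws_style_loss(preds, targets, T=0.25):
    """Hypothetical loss: mean H(sharpen(target), pred) minus H(mean of sharpened preds).

    Minimizing this lowers the cross-entropy term while raising the
    entropy of the average prediction (the mean-entropy term).
    """
    n = len(preds)
    ce = sum(cross_entropy(sharpen(t, T), p) for p, t in zip(preds, targets)) / n
    sharpened = [sharpen(p, T) for p in preds]
    num_classes = len(preds[0])
    mean_pred = [sum(s[k] for s in sharpened) / n for k in range(num_classes)]
    return ce - entropy(mean_pred)


# Confident predictions spread over all three classes vs. a collapsed batch.
confident = [[0.9, 0.05, 0.05], [0.05, 0.9, 0.05], [0.05, 0.05, 0.9]]
collapsed = [[0.9, 0.05, 0.05]] * 3
print(paws_style_loss(confident, confident) < paws_style_loss(collapsed, collapsed))  # True
```

In the collapsed batch, every prediction picks the same class. It gets the higher loss because its mean prediction has almost no entropy, and that kind of collapse is exactly what the mean-entropy term discourages.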

To evaluate the effectiveness of pre-training, PAWS performs soft nearest neighbor classification and reports the top-1 accuracy on a given test set.
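Soft nearest neighbor classification assigns a query the similarity-weighted average of the labels of a labeled support set. The weights come from a softmax over cosine similarities. A minimal pure-Python sketch follows. The function names and the temperature `tau=0.1` are illustrative assumptions, not values taken from nn_eval.py.

```python
import math


def l2_normalize(v):
    norm = math.sqrt(sum(x * x for x in v)) or 1.0
    return [x / norm for x in v]


def soft_nn_predict(query, support_embs, support_onehots, tau=0.1):
    """Hypothetical soft nearest neighbor classifier.

    Returns class probabilities as a softmax(cosine similarity / tau)-weighted
    average of the support samples' one-hot labels.
    """
    q = l2_normalize(query)
    sims = [sum(a * b for a, b in zip(q, l2_normalize(s))) / tau for s in support_embs]
    m = max(sims)  # subtract the max for a numerically stable softmax
    weights = [math.exp(s - m) for s in sims]
    total = sum(weights)
    weights = [w / total for w in weights]
    num_classes = len(support_onehots[0])
    return [sum(w * y[k] for w, y in zip(weights, support_onehots)) for k in range(num_classes)]


support = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]]
labels = [[1, 0], [1, 0], [0, 1]]
probs = soft_nn_predict([0.8, 0.2], support, labels)
print(max(range(len(probs)), key=probs.__getitem__))  # predicted class index: 0
```

Top-1 accuracy is then the fraction of test images whose highest-probability class matches the ground truth. A lower `tau` makes the vote depend more heavily on the closest support samples.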

Top-1 Accuracy

This repository reaches 73.46% top-1 accuracy on the CIFAR10 test set. Again, note that I only pre-trained for 50 epochs (as opposed to 600) and fine-tuned for 10 epochs (as opposed to 30). With the original schedule, this score should be around 96.0%.

In the PCA projection plot of the image embeddings (computed after fine-tuning), the PAWS embeddings are starting to be well separated.

Notebooks 📘

There are two Colab Notebooks:

Misc ⺟

  • Model weights are available here for reproducibility.
  • Mixed-precision training could further improve performance. I am open to contributions that implement mixed-precision training in the current code.

Acknowledgements

  • Huge thanks to Mahmoud Assran (first author of PAWS) for patiently resolving my doubts.
  • ML-GDE program for providing GCP credit support.

Paper Citation

@misc{assran2021semisupervised,
      title={Semi-Supervised Learning of Visual Features by Non-Parametrically Predicting View Assignments with Support Samples}, 
      author={Mahmoud Assran and Mathilde Caron and Ishan Misra and Piotr Bojanowski and Armand Joulin and Nicolas Ballas and Michael Rabbat},
      year={2021},
      eprint={2104.13963},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
Releases(v1.0.0)
Owner
Sayak Paul
Trying to learn how machines learn.