PyTorch META-DATASET (Few-shot classification benchmark)

Overview

This repo contains a PyTorch implementation of meta-dataset and a unified implementation of some few-shot methods. This repo may be useful to you if you:

  • want some pre-trained ImageNet models in PyTorch for META-DATASET;
  • want to benchmark your method on META-DATASET (but do not want to mix your PyTorch code with the original TensorFlow implementation);
  • are looking for a codebase to visualize few-shot episodes.

Benefits over original code:

  1. This repo can be properly seeded, so the same random series of episodes can be reproduced if needed;
  2. Data shuffling is performed without using a buffer, hence reducing the memory consumption;
  3. Better results can be obtained using this repo thanks to an enhanced way of resizing images. More details in the paper.
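
The seeding benefit in point 1 can be illustrated with a minimal sketch. This is not the repo's episode sampler: `sample_episode`, `episode_series`, the class names, and the way/shot values are hypothetical, and only Python's standard `random` module is used. It shows that a generator built from a fixed seed yields the same series of episodes on every run.

```python
import random


def sample_episode(rng, classes, n_way, n_shot, images_per_class):
    """Draw one episode: n_way classes, n_shot image indices per class."""
    chosen = rng.sample(classes, n_way)
    return {c: sorted(rng.sample(range(images_per_class), n_shot)) for c in chosen}


def episode_series(seed, n_episodes=3):
    """Build a series of episodes from a dedicated, seeded generator."""
    rng = random.Random(seed)  # private generator, independent of global state
    classes = [f"class_{i}" for i in range(20)]  # hypothetical class names
    return [
        sample_episode(rng, classes, n_way=5, n_shot=1, images_per_class=50)
        for _ in range(n_episodes)
    ]


# Two runs with the same seed produce identical series of episodes.
series_a = episode_series(seed=2021)
series_b = episode_series(seed=2021)
print(series_a == series_b)
```

Using a dedicated `random.Random` instance rather than the global generator keeps the episode series unaffected by other code that draws random numbers.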

Note that this code also includes the original implementation for comparison (using the PyTorch workaround proposed by the authors). If you wish to use the original implementation, set the option loader_version: 'tf' in base.yaml (by default set to pytorch).

Yet to do:

  1. Add more methods
  2. Test for the multi-source setting

1. Setting up

Please carefully follow the instructions below to get started.

1.1 Requirements

This code was developed and tested with Python 3.8. The requirements are listed in requirements.txt:

pip install -r requirements.txt

1.2 Data

To download META-DATASET, please follow the detailed instructions provided at meta-dataset to obtain the converted .tfrecords data. Once done, make sure all converted datasets are in a single folder, and execute the following script to produce index files:

bash scripts/make_records/make_index_files.sh <path_to_converted_data>

This may take a few minutes. Once all this is done, set the path variable in config/base.yaml to your data folder.

1.3 Download pre-trained models

We provide trained Resnet-18 and WRN-2810 models on the training split of ILSVRC_2012 at checkpoints. All non-episodic baselines use the same checkpoint, stored in the standard folder. The results (averaged over 600 episodes) obtained with the provided Resnet-18 are summarized below:

Inductive methods     Architecture  ILSVRC  Omniglot  Aircraft  Birds  Textures  Quick Draw  Fungi  VGG Flower  Traffic Signs  MSCOCO  Mean
Finetune              Resnet-18     59.8    60.5      63.5      80.6   80.9      61.5        45.2   91.1        55.1           41.8    64.0
ProtoNet              Resnet-18     48.2    46.7      44.6      53.8   70.3      45.1        38.5   82.4        42.2           38.0    51.0
SimpleShot            Resnet-18     60.0    54.2      55.9      78.6   77.8      57.4        49.2   90.3        49.6           44.2    61.7

Transductive methods  Architecture  ILSVRC  Omniglot  Aircraft  Birds  Textures  Quick Draw  Fungi  VGG Flower  Traffic Signs  MSCOCO  Mean
BD-CSPN               Resnet-18     60.5    54.4      55.2      80.9   77.9      57.3        50.0   91.7        47.8           43.9    62.0
TIM-GD                Resnet-18     63.6    65.6      66.4      85.6   84.7      65.8        57.5   95.6        65.2           50.9    70.1

See Sect. 1.4 and 1.5 to reproduce these results.

1.4 Train models from scratch (optional)

To train your model from scratch, execute the scripts/train.sh script:

bash scripts/train.sh <method> <architecture> <dataset>

method is to be chosen among the method-specific config files in config/, architecture in ['resnet18', 'wideres2810'], and dataset among all datasets (as named by the META-DATASET converted folders). Note that the arguments passed to src/train.py and src/eval.py follow this order of precedence: base_config < method_config < opts arguments. In other words, a value set in the method config overrides the base config, and an opts argument overrides both.
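
The precedence rule can be sketched as a simple dictionary merge. This is only an illustration under stated assumptions, not the code used by src/train.py or src/eval.py: `merge_configs` is a hypothetical helper, the option values (including the method name "TIM_GD") are made up, and nested config sections are not handled.

```python
def merge_configs(base_config, method_config, opts):
    """Merge three option sources; later sources override earlier ones."""
    merged = dict(base_config)    # lowest precedence
    merged.update(method_config)  # overrides base_config
    merged.update(opts)           # highest precedence
    return merged


# Hypothetical values, chosen only to show which source wins.
base_config = {"loader_version": "pytorch", "gpus": [0], "iter": 1}
method_config = {"method": "TIM_GD", "iter": 50}
opts = {"gpus": [0, 1, 2, 3]}

config = merge_configs(base_config, method_config, opts)
print(config)
```

Here `iter` comes from the method config, `gpus` from the opts arguments, and `loader_version` keeps its base value because nothing overrides it.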

Multiprocessing: This code supports distributed training. To leverage this feature, set the gpus option accordingly (for instance gpus: [0, 1, 2, 3]).

1.5 Test your models

Once your model is trained (or once the pre-trained models are downloaded), you can evaluate it on the test split of each dataset by running:

bash scripts/test.sh <method> <architecture> <base_dataset> <test_dataset>

Results will be saved in a subfolder of results/ whose name contains a unique hash number of the config (two runs share the same result folder if and only if all their hyperparameters are the same).
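
The idea of naming a result folder after a hash of the config can be sketched as follows. The repo's actual hashing scheme is not described here, so this is an assumption-laden illustration: `config_hash`, the use of SHA-256 over a JSON dump, and the option names and values are all hypothetical.

```python
import hashlib
import json


def config_hash(config, length=8):
    """Return a short, deterministic hash of a config dictionary."""
    # sort_keys makes the serialization independent of key insertion order.
    payload = json.dumps(config, sort_keys=True).encode("utf-8")
    return hashlib.sha256(payload).hexdigest()[:length]


run_a = {"method": "SimpleShot", "arch": "resnet18", "iter": 1}
run_b = {"iter": 1, "arch": "resnet18", "method": "SimpleShot"}  # same options, other order
run_c = {"method": "SimpleShot", "arch": "resnet18", "iter": 2}  # one option changed

print(config_hash(run_a) == config_hash(run_b))
print(config_hash(run_a) == config_hash(run_c))
```

With such a scheme, rerunning an experiment with identical hyperparameters writes to the same folder, while changing any option produces a new one.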

2. Visualization of results

2.1 Training metrics

During training, the training loss and validation accuracy are recorded and saved as .npy files in the checkpoint folder. You can then use src/plot.py to plot these metrics, even while training is still running.

Example 1: Plot the metrics of the standard (=non episodic) resnet-18 on ImageNet:

python src/plot.py --folder checkpoints/ilsvrc_2012/ilsvrc_2012/resnet18/standard/

Example 2: Plot the metrics of all Resnet-18 models trained on ImageNet:

python src/plot.py --folder checkpoints/ilsvrc_2012/ilsvrc_2012/resnet18/

2.2 Inference metrics

For methods that perform test-time optimization (for instance MAML, TIM, Finetune, ...), method-specific metrics are plotted in real time (versus test iterations) and averaged over test episodes, which lets you spot unexpected behavior easily. Such metrics are implemented in src/metrics/, and the choice of which metrics to plot is specified through the eval_metrics option in the method .yaml config file. An example with the TIM method is provided below.
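
As a rough illustration of what "averaged over test episodes, versus test iterations" means, the sketch below averages a metric across episodes at each iteration. It is not taken from src/metrics/: `average_over_episodes` and the accuracy values are hypothetical.

```python
def average_over_episodes(per_episode_values):
    """Average a metric over episodes, separately for each test iteration.

    per_episode_values[e][t] is the metric of episode e at iteration t.
    """
    n_episodes = len(per_episode_values)
    n_iters = len(per_episode_values[0])
    if any(len(values) != n_iters for values in per_episode_values):
        raise ValueError("every episode must record the same number of iterations")
    return [
        sum(values[t] for values in per_episode_values) / n_episodes
        for t in range(n_iters)
    ]


# Hypothetical accuracies of 3 episodes over 4 test-time optimization steps.
accuracy = [
    [0.50, 0.60, 0.70, 0.75],
    [0.40, 0.55, 0.65, 0.70],
    [0.60, 0.65, 0.75, 0.80],
]
curve = average_over_episodes(accuracy)
print([round(v, 3) for v in curve])
```

The resulting list has one value per test iteration, which is the kind of curve that would be plotted against the iteration index.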

2.3 Visualization of episodes

By setting the option visu: True at inference, you can visualize samples of episodes. An example of such visualization is given below:

The samples will be saved in results/. All relevant options can be found in the base.yaml file, in the EVAL-VISU section.

3. Incorporate your own method

This code was designed to allow easy incorporation of new methods.

Step 1: Add your method .py file to src/methods/ by following the template provided in src/methods/method.py.

Step 2: Import your method in src/methods/__init__.py.

Step 3: Add your method .yaml config file including the required options episodic_training and method (name of the class corresponding to your method). Also make sure that if your method performs test-time optimization, you also properly set the option iter that specifies the number of optimization steps performed at inference (this argument is also used to plot the inference metrics, see section 2.2).
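
The options named in Step 3 can be checked with a small sketch like the one below. It assumes the .yaml file has already been loaded into a dictionary. `validate_method_config`, its `test_time_optimization` flag, and the method name MyMethod are hypothetical and not part of this repo.

```python
def validate_method_config(config, test_time_optimization=False):
    """Check that the options listed as required in Step 3 are present."""
    missing = [key for key in ("episodic_training", "method") if key not in config]
    # iter is only needed when the method optimizes at test time.
    if test_time_optimization and "iter" not in config:
        missing.append("iter")
    if missing:
        raise KeyError(f"missing required option(s): {', '.join(missing)}")
    return True


# Hypothetical config for a new method implemented in a class named MyMethod.
my_config = {"episodic_training": False, "method": "MyMethod", "iter": 20}
print(validate_method_config(my_config, test_time_optimization=True))
```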

4. Contributions

Contributions are more than welcome. In particular, if you want to add methods or pre-trained models, please open a pull request.

5. Citation

If you find this repo useful for your research, please consider citing the following paper:

@misc{boudiaf2021mutualinformation,
      title={Mutual-Information Based Few-Shot Classification}, 
      author={Malik Boudiaf and Ziko Imtiaz Masud and Jérôme Rony and Jose Dolz and Ismail Ben Ayed and Pablo Piantanida},
      year={2021},
      eprint={2106.12252},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Additionally, do not hesitate to file issues if you encounter problems, or reach out directly to Malik Boudiaf.

6. Acknowledgments

I thank the authors of meta-dataset for releasing their code, and the author of the open-source TFRecord reader for open-sourcing an awesome PyTorch-compatible TFRecordReader! Also, big thanks to @hkervadec for his thorough code review!
