Learning to Prompt for Continual Learning

Overview

Learning to Prompt for Continual Learning (L2P) Official Jax Implementation

L2P is a novel continual learning technique which learns to dynamically prompt a pre-trained model to learn tasks sequentially under different task transitions. Different from mainstream rehearsal-based or architecture-based methods, L2P requires neither a rehearsal buffer nor test-time task identity. L2P can be generalized to various continual learning settings including the most challenging and realistic task-agnostic setting. L2P consistently outperforms prior state-of-the-art methods. Surprisingly, L2P achieves competitive results against rehearsal-based methods even without a rehearsal buffer.

Code is written by Zifeng Wang. Acknowledgement to https://github.com/google-research/nested-transformer.

This is not an officially supported Google product.

Enviroment setup

pip install -r requirements.txt

Getting pretrained ViT model

ViT-B/16 model used in this paper can be downloaded at here.

Instructions on running L2P

We provide the configuration file to train and evaluate L2P on multiple benchmarks in configs.

To run our method on the Split CIFAR-100 dataset (class-incremental setting):

python -m main.py --my_config configs/cifar100_l2p.py --workdir=./cifar100_l2p --my_config.init_checkpoint=<ViT-saved-path/ViT-B_16.npz>

To run our method on the more complex Gaussian Scheduled CIFAR-100 dataset (task-agnostic setting):

python -m main.py --my_config configs/cifar100_gaussian_l2p.py --workdir=./cifar100_gaussian_l2p --my_config.init_checkpoint=<ViT-saved-path/ViT-B_16.npz>

Note: we run our experiments using 8 V100 GPUs or 4 TPUs, and we specify a per device batch size of 16 in the config files. This indicates that we use a total batch size of 128.

Visualize results

We use tensorboard to visualize the result. For example, if the working directory specified to run L2P is workdir=./cifar100_l2p, the command to check result is as follows:

tensorboard --logdir ./cifar100_l2p

Here are the important metrics to keep track of, and their corresponding meanings:

Metric Description
accuracy_n Accuracy of the n-th task
forgetting Average forgetting up until the current task
avg_acc Average evaluation accuracy up until the current task

Cite

@inproceedings{wang2021learning,
  title={Learning to Prompt for Continual Learning},
  author={Zifeng Wang and Zizhao Zhang and Chen-Yu Lee and Han Zhang and Ruoxi Sun and Xiaoqi Ren and Guolong Su and Vincent Perot and Jennifer Dy and Tomas Pfister},
  booktitle={arXiv preprint arXiv:2112.08654},
  year={2021}
}
Code for NeurIPS 2021 paper 'Spatio-Temporal Variational Gaussian Processes'

Spatio-Temporal Variational GPs This repository is the official implementation of the methods in the publication: O. Hamelijnck, W.J. Wilkinson, N.A.

AaltoML 26 Sep 16, 2022
NBEATSx: Neural basis expansion analysis with exogenous variables

NBEATSx: Neural basis expansion analysis with exogenous variables We extend the NBEATS model to incorporate exogenous factors. The resulting method, c

Cristian Challu 100 Dec 31, 2022
Demo notebooks for Qiskit application modules demo sessions (Oct 8 & 15):

qiskit-application-modules-demo-sessions This repo hosts demo notebooks for the Qiskit application modules demo sessions hosted on Qiskit YouTube. Par

Qiskit Community 46 Nov 24, 2022
This is the official PyTorch implementation for "Mesa: A Memory-saving Training Framework for Transformers".

Mesa: A Memory-saving Training Framework for Transformers This is the official PyTorch implementation for Mesa: A Memory-saving Training Framework for

Zhuang AI Group 105 Dec 06, 2022
An ever-growing playground of notebooks showcasing CLIP's impressive zero-shot capabilities.

Playground for CLIP-like models Demo Colab Link GradCAM Visualization Naive Zero-shot Detection Smarter Zero-shot Detection Captcha Solver Changelog 2

Kevin Zakka 101 Dec 30, 2022
Official code for the CVPR 2021 paper "How Well Do Self-Supervised Models Transfer?"

How Well Do Self-Supervised Models Transfer? This repository hosts the code for the experiments in the CVPR 2021 paper How Well Do Self-Supervised Mod

Linus Ericsson 157 Dec 16, 2022
Lava-DL, but with PyTorch-Lightning flavour

Deep learning project seed Use this seed to start new deep learning / ML projects. Built in setup.py Built in requirements Examples with MNIST Badges

Sami BARCHID 4 Oct 31, 2022
A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.

TorchRL Disclaimer This library is not officially released yet and is subject to change. The features are available before an official release so that

Meta Research 860 Jan 07, 2023
Orbivator AI - To Determine which features of data (measurements) are most important for diagnosing breast cancer and find out if breast cancer occurs or not.

Orbivator_AI Breast Cancer Wisconsin (Diagnostic) GOAL To Determine which features of data (measurements) are most important for diagnosing breast can

anurag kumar singh 1 Jan 02, 2022
3D mesh stylization driven by a text input in PyTorch

Text2Mesh [Project Page] Text2Mesh is a method for text-driven stylization of a 3D mesh, as described in "Text2Mesh: Text-Driven Neural Stylization fo

Threedle (University of Chicago) 649 Dec 27, 2022
Unsupervised phone and word segmentation using dynamic programming on self-supervised VQ features.

Unsupervised Phone and Word Segmentation using Vector-Quantized Neural Networks Overview Unsupervised phone and word segmentation on speech data is pe

Herman Kamper 13 Dec 11, 2022
USAD - UnSupervised Anomaly Detection on multivariate time series

USAD - UnSupervised Anomaly Detection on multivariate time series Scripts and utility programs for implementing the USAD architecture. Implementation

116 Jan 04, 2023
Conditional Generative Adversarial Networks (CGAN) for Mobility Data Fusion

This code implements the paper, Kim et al. (2021). Imputing Qualitative Attributes for Trip Chains Extracted from Smart Card Data Using a Conditional Generative Adversarial Network. Transportation Re

Eui-Jin Kim 2 Feb 03, 2022
An open source python library for automated feature engineering

"One of the holy grails of machine learning is to automate more and more of the feature engineering process." ― Pedro Domingos, A Few Useful Things to

alteryx 6.4k Jan 03, 2023
g9.py - Torch interactive graphics

g9.py - Torch interactive graphics A Torch toy in the browser. Demo at https://srush.github.io/g9py/ This is a shameless copy of g9.js, written in Pyt

Sasha Rush 13 Nov 16, 2022
Evaluation Pipeline for our ECCV2020: Journey Towards Tiny Perceptual Super-Resolution.

Journey Towards Tiny Perceptual Super-Resolution Test code for our ECCV2020 paper: https://arxiv.org/abs/2007.04356 Our x4 upscaling pre-trained model

Royson 6 Mar 30, 2022
Time Delayed NN implemented in pytorch

Pytorch Time Delayed NN Time Delayed NN implemented in PyTorch. Usage kernels = [(1, 25), (2, 50), (3, 75), (4, 100), (5, 125), (6, 150)] tdnn = TDNN

Daniil Gavrilov 79 Aug 04, 2022
Making self-supervised learning work on molecules by using their 3D geometry to pre-train GNNs. Implemented in DGL and Pytorch Geometric.

3D Infomax improves GNNs for Molecular Property Prediction Video | Paper We pre-train GNNs to understand the geometry of molecules given only their 2D

Hannes Stärk 95 Dec 30, 2022
Supplementary code for the AISTATS 2021 paper "Matern Gaussian Processes on Graphs".

Matern Gaussian Processes on Graphs This repo provides an extension for gpflow with Matérn kernels, inducing variables and trainable models implemente

41 Dec 17, 2022
How to train a CNN to 99% accuracy on MNIST in less than a second on a laptop

Training a NN to 99% accuracy on MNIST in 0.76 seconds A quick study on how fast you can reach 99% accuracy on MNIST with a single laptop. Our answer

Tuomas Oikarinen 42 Dec 10, 2022