The official PyTorch implementation of recent paper - SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training

Overview

This repository is the official PyTorch implementation of SAINT. Find the paper on arxiv

SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training

Overview

Requirements

We recommend using anaconda or miniconda for python. Our code has been tested with python=3.8 on linux.

To create a new environment with conda

conda create -n saint_env python=3.8
conda activate saint_env

We recommend installing the latest pytorch, torchvision, einops, pandas, wget, sklearn packages.

You can install them using

conda install pytorch torchvision -c pytorch
conda install -c conda-forge einops 
conda install -c conda-forge pandas 
conda install -c conda-forge python-wget 
conda install -c anaconda scikit-learn 

Make sure the following requirements are met

  • torch>=1.8.1
  • torchvision>=0.9.1

Optional

We used wandb to update our logs. But it is optional.

conda install -c conda-forge wandb 

Training & Evaluation

In each of our experiments, we use a single Nvidia GeForce RTX 2080Ti GPU.

First download the processed datasets from this link into the folder ./data

To train the model(s) in the paper, run this command:

python train.py  --dataset <dataset_name> --attentiontype <attention_type> 

Pretraining is useful when there are few training data samples. Sample code looks like this

python train.py  --dataset <dataset_name> --attentiontype <attention_type> --pretrain --pt_tasks <pretraining_task_touse> --pt_aug <augmentations_on_data_touse> --ssl_avail_y <Number_of_labeled_samples>

Train all 16 datasets by running bash files. train.sh for supervised learning and train_pt.sh for pretraining and semi-supervised learning

bash train.sh
bash train_pt.sh

Arguments

  • --dataset : Dataset name. We support only the 16 datasets discussed in the paper. Supported datasets are ['1995_income','bank_marketing','qsar_bio','online_shoppers','blastchar','htru2','shrutime','spambase','philippine','mnist','arcene','volkert','creditcard','arrhythmia','forest','kdd99']
  • --embedding_size : Size of the feature embeddings
  • --transformer_depth : Depth of the model. Number of stages.
  • --attention_heads : Number of attention heads in each Attention layer.
  • --cont_embeddings : Style of embedding continuous data.
  • --attentiontype : Variant of SAINT. 'col' refers to SAINT-s variant, 'row' is SAINT-i, and 'colrow' refers to SAINT.
  • --pretrain : To enable pretraining
  • --pt_tasks : Losses we want to use for pretraining. Multiple arguments can be passed.
  • --pt_aug : Types of data augmentations used in pretraining. Multiple arguments are allowed. We support only mixup and CutMix right now.
  • --ssl_avail_y : Number of labeled samples used in semi-supervised experiments. Default is 0, which means all samples are labeled and is supervised case.
  • --pt_projhead_style : Projection head style used in contrastive pipeline.
  • --nce_temp : Temperature used in contrastive loss function.
  • --active_log : To update the logs onto wandb. This is optional

Evaluation

We choose the best model by evaluating the model on validation dataset. The AUROC(for binary classification datasets) and Accuracy (for multiclass classification datasets) of the best model on test datasets is printed after training is completed. If wandb is enabled, they are logged to 'test_auroc_bestep', 'test_accuracy_bestep' variables.

Acknowledgements

We would like to thank the following public repo from which we borrowed various utilites.

License

This repository is released under the Apache 2.0 license as found in the LICENSE file.

Cite us

@article{somepalli2021saint,
  title={SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training},
  author={Somepalli, Gowthami and Goldblum, Micah and Schwarzschild, Avi and Bruss, C Bayan and Goldstein, Tom},
  journal={arXiv preprint arXiv:2106.01342},
  year={2021}
}

Owner
Gowthami Somepalli
Gowthami Somepalli
Official implementation for (Show, Attend and Distill: Knowledge Distillation via Attention-based Feature Matching, AAAI-2021)

Show, Attend and Distill: Knowledge Distillation via Attention-based Feature Matching Official pytorch implementation of "Show, Attend and Distill: Kn

Clova AI Research 80 Dec 16, 2022
My coursework for Machine Learning (2021 Spring) at National Taiwan University (NTU)

Machine Learning 2021 Machine Learning (NTU EE 5184, Spring 2021) Instructor: Hung-yi Lee Course Website : (https://speech.ee.ntu.edu.tw/~hylee/ml/202

100 Dec 26, 2022
CLIP (Contrastive Language–Image Pre-training) trained on Indonesian data

CLIP-Indonesian CLIP (Radford et al., 2021) is a multimodal model that can connect images and text by training a vision encoder and a text encoder joi

Galuh 17 Mar 10, 2022
Distributing reference energies for SMIRNOFF implementations

Warning: This code is currently experimental and under active development. Is it not yet suitable for distribution or use as reference implementation.

Open Force Field Initiative 1 Dec 07, 2021
A rule-based log analyzer & filter

Flog 一个根据规则集来处理文本日志的工具。 前言 在日常开发过程中,由于缺乏必要的日志规范,导致很多人乱打一通,一个日志文件夹解压缩后往往有几十万行。 日志泛滥会导致信息密度骤减,给排查问题带来了不小的麻烦。 以前都是用grep之类的工具先挑选出有用的,再逐条进行排查,费时费力。在忍无可忍之后决

上山打老虎 9 Jun 23, 2022
Using contrastive learning and OpenAI's CLIP to find good embeddings for images with lossy transformations

The official code for the paper "Inverse Problems Leveraging Pre-trained Contrastive Representations" (to appear in NeurIPS 2021).

Sriram Ravula 26 Dec 10, 2022
Build a small, 3 domain internet using Github pages and Wikipedia and construct a crawler to crawl, render, and index.

TechSEO Crawler Build a small, 3 domain internet using Github pages and Wikipedia and construct a crawler to crawl, render, and index. Play with the r

JR Oakes 57 Nov 24, 2022
OpenMMLab Image Classification Toolbox and Benchmark

Introduction English | 简体中文 MMClassification is an open source image classification toolbox based on PyTorch. It is a part of the OpenMMLab project. D

OpenMMLab 1.8k Jan 03, 2023
Deep Q Learning with OpenAI Gym and Pokemon Showdown

pokemon-deep-learning An openAI gym project for pokemon involving deep q learning. Made by myself, Sam Little, and Layton Webber. This code captures g

2 Dec 22, 2021
Semi-Supervised Semantic Segmentation via Adaptive Equalization Learning, NeurIPS 2021 (Spotlight)

Semi-Supervised Semantic Segmentation via Adaptive Equalization Learning, NeurIPS 2021 (Spotlight) Abstract Due to the limited and even imbalanced dat

Hanzhe Hu 99 Dec 12, 2022
PyTorch implementations of Generative Adversarial Networks.

This repository has gone stale as I unfortunately do not have the time to maintain it anymore. If you would like to continue the development of it as

Erik Linder-Norén 13.4k Jan 08, 2023
Code for paper "Document-Level Argument Extraction by Conditional Generation". NAACL 21'

Argument Extraction by Generation Code for paper "Document-Level Argument Extraction by Conditional Generation". NAACL 21' Dependencies pytorch=1.6 tr

Zoey Li 87 Dec 26, 2022
I3-master-layout - Simple master and stack layout script

Simple master and stack layout script | ------ | ----- | | | | | Ma

Tobias S 18 Dec 05, 2022
Code for the paper: Fighting Fake News: Image Splice Detection via Learned Self-Consistency

Fighting Fake News: Image Splice Detection via Learned Self-Consistency [paper] [website] Minyoung Huh *12, Andrew Liu *1, Andrew Owens1, Alexei A. Ef

minyoung huh (jacob) 174 Dec 09, 2022
SegNet-Basic with Keras

SegNet-Basic: What is Segnet? Deep Convolutional Encoder-Decoder Architecture for Semantic Pixel-wise Image Segmentation Segnet = (Encoder + Decoder)

Yad Konrad 81 Jun 30, 2022
Code of U2Fusion: a unified unsupervised image fusion network for multiple image fusion tasks, including multi-modal, multi-exposure and multi-focus image fusion.

U2Fusion Code of U2Fusion: a unified unsupervised image fusion network for multiple image fusion tasks, including multi-modal (VIS-IR, medical), multi

Han Xu 129 Dec 11, 2022
[NeurIPS 2021] Galerkin Transformer: a linear attention without softmax

[NeurIPS 2021] Galerkin Transformer: linear attention without softmax Summary A non-numerical analyst oriented explanation on Toward Data Science abou

Shuhao Cao 159 Dec 20, 2022
This code is an implementation for Singing TTS.

MLP Singer This code is an implementation for Singing TTS. The algorithm is based on the following papers: Tae, J., Kim, H., & Lee, Y. (2021). MLP Sin

Heejo You 22 Dec 23, 2022
The repo for reproducing Seed-driven Document Ranking for Systematic Reviews: A Reproducibility Study

ECIR Reproducibility Paper: Seed-driven Document Ranking for Systematic Reviews: A Reproducibility Study This code corresponds to the reproducibility

ielab 3 Mar 31, 2022
This code is part of the reproducibility package for the SANER 2022 paper "Generating Clarifying Questions for Query Refinement in Source Code Search".

Clarifying Questions for Query Refinement in Source Code Search This code is part of the reproducibility package for the SANER 2022 paper "Generating

Zachary Eberhart 0 Dec 04, 2021