Contrastive Learning Inverts the Data Generating Process

Overview

Official code to reproduce the results and data presented in the paper Contrastive Learning Inverts the Data Generating Process.

3DIdent dataset example images

Experiments

To reproduce the disentanglement results for the MLP mixing, use the main_mlp.py script. For the experiments on KITTI Masks use the main_kitti.py script. For those on 3DIdent, use main_3dident.py.
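
As a rough illustration of how the three entry points might be driven from Python, the sketch below assembles an argument list for one of the scripts. The script names and flags come from this README; the helper `build_command`, the experiment keys, and the example values are hypothetical and not part of the repository.

```python
import shlex
import sys

# Script names are taken from the README; the dictionary keys are hypothetical.
SCRIPTS = {
    "mlp": "main_mlp.py",
    "kitti": "main_kitti.py",
    "3dident": "main_3dident.py",
}


def build_command(experiment, **flags):
    """Return an argv list for one experiment script.

    Keyword names use underscores and are converted to dashed flags,
    e.g. space_type="sphere" becomes "--space-type sphere".
    A value of True emits a bare switch; False or None omits the flag.
    Note: main_kitti.py spells --num_runs with an underscore, which this
    helper does not handle.
    """
    argv = [sys.executable, SCRIPTS[experiment]]
    for name, value in flags.items():
        flag = "--" + name.replace("_", "-")
        if value is True:
            argv.append(flag)
        elif value is False or value is None:
            continue
        else:
            argv.extend([flag, str(value)])
    return argv


if __name__ == "__main__":
    # Example values are placeholders, not recommended settings.
    cmd = build_command("mlp", space_type="sphere", n=10, seed=0,
                        only_unsupervised=True)
    print(shlex.join(cmd[1:]))
    # To actually launch the run: subprocess.run(cmd, check=True)
```

The list form can be passed directly to `subprocess.run`, which avoids shell quoting issues for paths such as `--save-dir`.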

MLP Mixing

> python main_mlp.py --help
usage: main_mlp.py
       [-h] [--sphere-r SPHERE_R] [--box-min BOX_MIN] [--box-max BOX_MAX]
       [--sphere-norm] [--box-norm] [--only-supervised] [--only-unsupervised]
       [--more-unsupervised MORE_UNSUPERVISED] [--save-dir SAVE_DIR]
       [--num-eval-batches NUM_EVAL_BATCHES] [--rej-mult REJ_MULT]
       [--seed SEED] [--act-fct ACT_FCT] [--c-param C_PARAM]
       [--m-param M_PARAM] [--tau TAU] [--n-mixing-layer N_MIXING_LAYER]
       [--n N] [--space-type {box,sphere,unbounded}] [--m-p M_P] [--c-p C_P]
       [--lr LR] [--p P] [--batch-size BATCH_SIZE] [--n-log-steps N_LOG_STEPS]
       [--n-steps N_STEPS] [--resume-training]

Disentanglement with InfoNCE/Contrastive Learning - MLP Mixing

optional arguments:
  -h, --help            show this help message and exit
  --sphere-r SPHERE_R
  --box-min BOX_MIN     For box normalization only. Minimal value of box.
  --box-max BOX_MAX     For box normalization only. Maximal value of box.
  --sphere-norm         Normalize output to a sphere.
  --box-norm            Normalize output to a box.
  --only-supervised     Only train supervised model.
  --only-unsupervised   Only train unsupervised model.
  --more-unsupervised MORE_UNSUPERVISED
                        How many more steps to do for unsupervised compared to
                        supervised training.
  --save-dir SAVE_DIR
  --num-eval-batches NUM_EVAL_BATCHES
                        Number of batches to average evaluation performance at
                        the end.
  --rej-mult REJ_MULT   Memory/CPU trade-off factor for rejection resampling.
  --seed SEED
  --act-fct ACT_FCT     Activation function in mixing network g.
  --c-param C_PARAM     Concentration parameter of the conditional
                        distribution.
  --m-param M_PARAM     Additional parameter for the marginal (only relevant
                        if it is not uniform).
  --tau TAU
  --n-mixing-layer N_MIXING_LAYER
                        Number of layers in nonlinear mixing network g.
  --n N                 Dimensionality of the latents.
  --space-type {box,sphere,unbounded}
  --m-p M_P             Type of ground-truth marginal distribution. p=0 means
                        uniform; all other p values correspond to (projected)
                        Lp Exponential
  --c-p C_P             Exponent of ground-truth Lp Exponential distribution.
  --lr LR
  --p P                 Exponent of the assumed model Lp Exponential
                        distribution.
  --batch-size BATCH_SIZE
  --n-log-steps N_LOG_STEPS
  --n-steps N_STEPS
  --resume-training
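
To make the roles of `--tau` and `--p` more concrete, here is a minimal NumPy sketch of an InfoNCE loss with an Lp-based similarity. It is a generic illustration written for this README, not the code in main_mlp.py: the function names, the choice of similarity `-||a - b||_p^p / tau`, and the toy data are all assumptions.

```python
import numpy as np


def lp_similarity(a, b, p=1, tau=1.0):
    """Pairwise similarity -||a_i - b_j||_p^p / tau, shape (len(a), len(b))."""
    diff = np.abs(a[:, None, :] - b[None, :, :])
    return -np.sum(diff ** p, axis=-1) / tau


def info_nce(anchors, positives, p=1, tau=1.0):
    """InfoNCE loss where positives[i] is the positive for anchors[i]
    and every other row of `positives` acts as a negative."""
    logits = lp_similarity(anchors, positives, p=p, tau=tau)
    # Numerically stable log-softmax over each row.
    logits = logits - logits.max(axis=1, keepdims=True)
    log_prob = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -np.mean(np.diag(log_prob))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    z = rng.uniform(0.0, 1.0, size=(256, 10))       # toy latents in a box
    z_pos = z + 0.05 * rng.laplace(size=z.shape)    # nearby positives
    z_rand = rng.uniform(0.0, 1.0, size=z.shape)    # unrelated pairs
    print("matched pairs:  ", info_nce(z, z_pos, p=1, tau=0.1))
    print("unrelated pairs:", info_nce(z, z_rand, p=1, tau=0.1))
```

In this sketch a smaller `tau` sharpens the softmax over candidates, and `p` selects which Lp distance defines closeness; matched pairs should yield a lower loss than unrelated ones.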

KITTI Masks

> python main_kitti.py --help
usage: main_kitti.py [-h] [--box-norm BOX_NORM] [--p P] [--experiment-dir EXPERIMENT_DIR] [--evaluate] [--specify SPECIFY] [--random-search] [--random-seeds] [--seed SEED] [--beta BETA] [--gamma GAMMA]
                     [--rate-prior RATE_PRIOR] [--data-distribution DATA_DISTRIBUTION] [--rate-data RATE_DATA] [--data-k DATA_K] [--betavae] [--search-beta] [--output-dir OUTPUT_DIR] [--log-dir LOG_DIR]
                     [--ckpt-dir CKPT_DIR] [--max-iter MAX_ITER] [--dataset DATASET] [--batch-size BATCH_SIZE] [--num-workers NUM_WORKERS] [--image-size IMAGE_SIZE] [--use-writer] [--z-dim Z_DIM] [--lr LR]
                     [--beta1 BETA1] [--beta2 BETA2] [--ckpt-name CKPT_NAME] [--log-step LOG_STEP] [--save-step SAVE_STEP] [--kitti-max-delta-t KITTI_MAX_DELTA_T] [--natural-discrete] [--verbose] [--cuda]
                     [--num_runs NUM_RUNS]

Disentanglement with InfoNCE/Contrastive Learning - KITTI Masks

optional arguments:
  -h, --help            show this help message and exit
  --box-norm BOX_NORM
  --p P
  --experiment-dir EXPERIMENT_DIR
                        specify path
  --evaluate            evaluate instead of train
  --specify SPECIFY     use argument to only compute a subset of metrics
  --random-search       whether to random search for params
  --random-seeds        whether to go over random seeds with UDR params
  --seed SEED           random seed
  --beta BETA           weight for kl to normal
  --gamma GAMMA         weight for kl to laplace
  --rate-prior RATE_PRIOR
                        rate (or inverse scale) for prior laplace (larger -> sparser).
  --data-distribution DATA_DISTRIBUTION
                        (laplace, uniform)
  --rate-data RATE_DATA
                        rate (or inverse scale) for data laplace (larger -> sparser). (-1 = rand).
  --data-k DATA_K       k for data uniform (-1 = rand).
  --betavae             whether to do standard betavae training (gamma=0)
  --search-beta         whether to do rand search over beta
  --output-dir OUTPUT_DIR
                        output directory
  --log-dir LOG_DIR     log directory
  --ckpt-dir CKPT_DIR   checkpoint directory
  --max-iter MAX_ITER   maximum training iteration
  --dataset DATASET     dataset name (dsprites, cars3d,smallnorb, shapes3d, mpi3d, kittimasks, natural
  --batch-size BATCH_SIZE
                        batch size
  --num-workers NUM_WORKERS
                        dataloader num_workers
  --image-size IMAGE_SIZE
                        image size. now only (64,64) is supported
  --use-writer          whether to use a log writer
  --z-dim Z_DIM         dimension of the representation z
  --lr LR               learning rate
  --beta1 BETA1         Adam optimizer beta1
  --beta2 BETA2         Adam optimizer beta2
  --ckpt-name CKPT_NAME
                        load previous checkpoint. insert checkpoint filename
  --log-step LOG_STEP   numer of iterations after which data is logged
  --save-step SAVE_STEP
                        number of iterations after which a checkpoint is saved
  --kitti-max-delta-t KITTI_MAX_DELTA_T
                        max t difference between frames sampled from kitti data loader.
  --natural-discrete    discretize natural sprites
  --verbose             for evaluation
  --cuda
  --num_runs NUM_RUNS   when searching over seeds, do 10

3DIdent

> python main_3dident.py --help
usage: main_3dident.py [-h] [--batch-size BATCH_SIZE] [--n-eval-samples N_EVAL_SAMPLES] [--lr LR] [--optimizer {adam,sgd}] [--iterations ITERATIONS]
                       [--n-log-steps N_LOG_STEPS] [--load-model LOAD_MODEL] [--save-model SAVE_MODEL] [--save-every SAVE_EVERY] [--no-cuda] [--position-only]
                       [--rotation-and-color-only] [--rotation-only] [--color-only] [--no-spotlight-position] [--no-spotlight-color] [--no-spotlight]
                       [--non-periodic-rotation-and-color] [--dummy-mixing] [--identity-solution] [--identity-mixing-and-solution]
                       [--approximate-dataset-nn-search] --offline-dataset OFFLINE_DATASET [--faiss-omp-threads FAISS_OMP_THREADS]
                       [--box-constraint {None,fix,learnable}] [--sphere-constraint {None,fix,learnable}] [--workers WORKERS]
                       [--mode {supervised,unsupervised,test}] [--supervised-loss {mse,r2}] [--unsupervised-loss {l1,l2,l3,vmf}]
                       [--non-periodical-conditional {l1,l2,l3}] [--sigma SIGMA] [--encoder {rn18,rn50,rn101,rn151}]

Disentanglement with InfoNCE/Contrastive Learning - 3DIdent

optional arguments:
  -h, --help            show this help message and exit
  --batch-size BATCH_SIZE
  --n-eval-samples N_EVAL_SAMPLES
  --lr LR
  --optimizer {adam,sgd}
  --iterations ITERATIONS
                        How long to train the model
  --n-log-steps N_LOG_STEPS
                        How often to calculate scores and print them
  --load-model LOAD_MODEL
                        Path from where to load the model
  --save-model SAVE_MODEL
                        Path where to save the model
  --save-every SAVE_EVERY
                        After how many steps to save the model (will always be saved at the end)
  --no-cuda
  --position-only
  --rotation-and-color-only
  --rotation-only
  --color-only
  --no-spotlight-position
  --no-spotlight-color
  --no-spotlight
  --non-periodic-rotation-and-color
  --dummy-mixing
  --identity-solution
  --identity-mixing-and-solution
  --approximate-dataset-nn-search
  --offline-dataset OFFLINE_DATASET
  --faiss-omp-threads FAISS_OMP_THREADS
  --box-constraint {None,fix,learnable}
  --sphere-constraint {None,fix,learnable}
  --workers WORKERS     Number of workers to use (0=#cpus)
  --mode {supervised,unsupervised,test}
  --supervised-loss {mse,r2}
  --unsupervised-loss {l1,l2,l3,vmf}
  --non-periodical-conditional {l1,l2,l3}
  --sigma SIGMA         Sigma of the conditional distribution (for vMF: 1/kappa)
  --encoder {rn18,rn50,rn101,rn151}
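
The help text mentions periodically computed scores and an `r2` option but does not spell out how they are calculated. One common way to check whether recovered latents match the ground truth up to an affine map is the R² of a least-squares fit. The sketch below shows that idea under stated assumptions: `linear_r2` is a hypothetical helper, and it is not claimed to match what main_3dident.py computes.

```python
import numpy as np


def linear_r2(z_true, z_pred):
    """R^2 of the best affine map from z_pred to z_true,
    averaged over latent dimensions (fit and scored on the same samples)."""
    # Append a bias column so the fit is affine rather than purely linear.
    design = np.hstack([z_pred, np.ones((len(z_pred), 1))])
    coef, *_ = np.linalg.lstsq(design, z_true, rcond=None)
    residual = z_true - design @ coef
    ss_res = np.sum(residual ** 2, axis=0)
    ss_tot = np.sum((z_true - z_true.mean(axis=0)) ** 2, axis=0)
    return float(np.mean(1.0 - ss_res / ss_tot))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    z_true = rng.uniform(-1.0, 1.0, size=(1000, 3))
    mixing = np.array([[2.0, 1.0, 0.0],
                       [0.0, 1.0, 0.0],
                       [1.0, 0.0, 3.0]])             # invertible toy map
    print("linearly mixed:", linear_r2(z_true, z_true @ mixing))
    print("independent:   ", linear_r2(z_true, rng.normal(size=z_true.shape)))
```

A score close to 1 means the prediction carries the true latents up to an affine transformation; a score near 0 means it carries essentially no linear information about them.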

3DIdent Dataset

We introduce 3DIdent, a dataset with hallmarks of natural environments (shadows, different lighting conditions, 3D rotations, etc.). A preliminary version of the dataset is released along with our pre-print.

You can access the dataset here. The training and test datasets consist of 250,000 and 25,000 samples, respectively. To load them, you can use the ThreeDIdentDataset class defined in datasets/threedident_dataset.py.

BibTeX

If you find our analysis helpful, please cite our pre-print:

@article{zimmermann2021cl,
  author = {
    Zimmermann, Roland S. and
    Sharma, Yash and
    Schneider, Steffen and
    Bethge, Matthias and
    Brendel, Wieland
  },
  title = {
    Contrastive Learning Inverts the Data Generating Process
  },
  journal = {CoRR},
  volume = {abs/2102.08850},
  year = {2021},
}