Cross-Modal Contrastive Learning for Text-to-Image Generation

Overview

This repository hosts the open source JAX implementation of XMC-GAN.

Setup instructions

Environment

Set up a virtualenv and activate it (the required libraries are installed in the steps below):

virtualenv venv
source venv/bin/activate

Add the XMC-GAN library to PYTHONPATH:

export PYTHONPATH=$PYTHONPATH:/home/path/to/xmcgan/root/

JAX Installation

Note: Please follow the official JAX instructions for installing a GPU compatible version of JAX.
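After installation, it can be useful to confirm which devices JAX actually sees. The snippet below is a minimal sketch, not part of the repository: the helper name list_accelerators is hypothetical, and it relies only on jax.devices() and the platform and id attributes of the returned devices. A CPU-only install typically reports only cpu devices.

```python
def list_accelerators():
    """Return (platform, id) pairs for JAX devices, or None if JAX is missing."""
    try:
        import jax
    except ImportError:
        return None
    return [(device.platform, device.id) for device in jax.devices()]


if __name__ == "__main__":
    devices = list_accelerators()
    if devices is None:
        print("JAX is not installed in this environment")
    else:
        for platform, device_id in devices:
            print(platform, device_id)
```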

Other Dependencies

After installing JAX, install the remaining dependencies with:

pip install -r requirements.txt

Preprocess COCO-2014

To create the training and eval data, first create a directory for it. By default, the scripts expect to save results in data/ in the base directory.

mkdir data/

The TFRecords required for training and validation on COCO-2014 can be created by running a preprocessing script over the TFDS coco_captions dataset:

python preprocess_data.py

This may take a while to complete, as it runs a pretrained BERT model over the captions and stores the embeddings. With a GPU, it takes about 2.5 hours for the train split and 1 hour for the validation split. Once it is done, the train and validation TFRecord files will be saved in the data/ directory. The train files require around 58 GB of disk space, and the validation files require around 29 GB.
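Since the two splits together need roughly 87 GB (58 + 29), it may be worth checking free space before starting the multi-hour run. The following is a minimal sketch using only the Python standard library; the helper name has_room_for_tfrecords and the decimal-gigabyte conversion are assumptions for illustration and do not come from the repository.

```python
import shutil

# Approximate sizes quoted above, in gigabytes.
TRAIN_GB = 58
VAL_GB = 29


def has_room_for_tfrecords(path=".", required_gb=TRAIN_GB + VAL_GB):
    """Return True if the filesystem holding `path` has enough free space."""
    free_bytes = shutil.disk_usage(path).free
    # Assumption: 1 GB is treated as 10**9 bytes.
    return free_bytes >= required_gb * 10**9


if __name__ == "__main__":
    print("required GB:", TRAIN_GB + VAL_GB)
    print("enough space:", has_room_for_tfrecords("."))
```

Pass the directory that will hold data/ as path if it lives on a different filesystem than the current directory.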

Note: If you run into an error related to TensorFlow gfile, one workaround is to edit site-packages/bert/tokenization.py and change tf.gfile.GFile to tf.io.gfile.GFile. For more details, refer to the following link.
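The manual edit described in the note can also be scripted. This is a hedged sketch rather than a tool shipped with the repository: the function names patch_gfile_calls and patch_file are hypothetical, and the path to tokenization.py inside site-packages must be supplied by the caller.

```python
from pathlib import Path

OLD = "tf.gfile.GFile"
NEW = "tf.io.gfile.GFile"


def patch_gfile_calls(source: str) -> str:
    """Replace every tf.gfile.GFile reference with tf.io.gfile.GFile."""
    return source.replace(OLD, NEW)


def patch_file(path: str) -> bool:
    """Rewrite `path` in place; return True if the file was changed."""
    target = Path(path)
    original = target.read_text()
    patched = patch_gfile_calls(original)
    if patched != original:
        target.write_text(patched)
    return patched != original
```

The replacement is safe to run twice, because the patched name does not contain the old one as a substring.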

If you run into a tensorflow.python.framework.errors_impl.ResourceExhaustedError about having too many open files, you may have to increase the machine's open file limits. To do so, open the limit configuration file for editing:

vi /etc/security/limits.conf

and append the following lines to the end of the file:

*         hard    nofile      500000
*         soft    nofile      500000
root      hard    nofile      500000
root      soft    nofile      500000

You may have to adjust the limit values depending on your machine. You will need to log out and log back in to your machine for these values to take effect.
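To see the limits that currently apply to your session, for example after logging back in, you can query them from Python. This is a minimal sketch using the standard-library resource module, which is available on Unix-like systems only; the helper names are illustrative.

```python
import resource


def open_file_limits():
    """Return the (soft, hard) open-file limits for the current process."""
    return resource.getrlimit(resource.RLIMIT_NOFILE)


def describe(limit):
    """Render a limit value, which may be unlimited."""
    return "unlimited" if limit == resource.RLIM_INFINITY else str(limit)


if __name__ == "__main__":
    soft, hard = open_file_limits()
    print("soft:", describe(soft), "hard:", describe(hard))
```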

Download Pretrained ResNet

To train XMC-GAN, we need a network pretrained on ImageNet to extract features. We trained a ResNet-50 network for this purpose. To download the weights, run:

gsutil cp gs://gresearch/xmcgan/resnet_pretrained.npy data/

If you would like to pretrain your own network on ImageNet, please refer to the official Flax ImageNet example.

Training

To start a training run, first edit train.sh to specify an appropriate work directory. By default, the script assumes that 8 GPUs are available and runs training on the first 7 GPUs, while test.sh assumes testing will run on the last GPU. After configuring the training job, start an experiment by running the script with bash:

mkdir exp
bash train.sh exp_name &> train.txt

Checkpoints and Tensorboard logs will be saved in /path/to/exp/exp_name. By default, the configs/coco_xmc.py config is used, which runs an experiment for 128px images. This configuration accommodates a batch size of 8 on each GPU, and achieves an FID of around 10.5 to 11.0 with the EMA weights. To reproduce the full results on 256px images in our paper, the full model needs to be run using a 32-core Pod slice of Google Cloud TPU v3 devices.
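The default device layout described above (first 7 of 8 GPUs for training, the last one for evaluation) can be written down as follows. This is only an illustrative sketch: how train.sh and test.sh actually select devices is not shown here, the function names are hypothetical, and the strings are formatted for the standard CUDA_VISIBLE_DEVICES environment variable as one plausible way to express the split. The total of 56 assumes the per-GPU batch of 8 is simply replicated across the 7 training GPUs.

```python
def split_gpus(num_gpus=8):
    """Reserve the last GPU for evaluation and the rest for training."""
    if num_gpus < 2:
        raise ValueError("need at least two GPUs to split train and eval")
    ids = list(range(num_gpus))
    return ids[:-1], ids[-1:]


def visible_devices(ids):
    """Format GPU ids as a CUDA_VISIBLE_DEVICES value."""
    return ",".join(str(i) for i in ids)


def global_batch_size(per_gpu_batch, train_ids):
    """Total images per step when each training GPU holds one batch."""
    return per_gpu_batch * len(train_ids)


if __name__ == "__main__":
    train_ids, eval_ids = split_gpus(8)
    print("train:", visible_devices(train_ids))
    print("eval:", visible_devices(eval_ids))
    print("global batch:", global_batch_size(8, train_ids))
```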

Evaluation

To run an evaluation job, update test.sh with the correct settings used in the training script. Then, execute

bash test.sh exp_name &> eval.txt

to start an evaluation job. All checkpoints in workdir will be evaluated for FID and Inception Score. If you can spare the GPUs, you can also run train.sh and test.sh in parallel, which will continuously evaluate new checkpoints saved into the work directory. Scores will be written to Tensorboard and output to eval.txt.

Tensorboard

To start a Tensorboard for monitoring training progress, run:

tensorboard --logdir /path/to/exp/exp_name

Citation

If you find this work useful, please consider citing:

@inproceedings{zhang2021cross,
  title={Cross-Modal Contrastive Learning for Text-to-Image Generation},
  author={Zhang, Han and Koh, Jing Yu and Baldridge, Jason and Lee, Honglak and Yang, Yinfei},
  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
  year={2021}
}

Disclaimer

Not an official Google product.
