A new mini-batch framework for optimal transport in deep generative models, deep domain adaptation, approximate Bayesian computation, color transfer, and gradient flow.

Related tags

MiscellaneousBoMb-OT
Overview

BoMb-OT

Python3 implementation of the papers On Transportation of Mini-batches: A Hierarchical Approach and Improving Mini-batch Optimal Transport via Partial Transportation.

Please CITE our papers whenever this repository is used to help produce published results or incorporated into other software.

@article{nguyen2021transportation,
      title={On Transportation of Mini-batches: A Hierarchical Approach}, 
      author={Khai Nguyen and Dang Nguyen and Quoc Nguyen and Tung Pham and Hung Bui and Dinh Phung and Trung Le and Nhat Ho},
      journal={arXiv preprint arXiv:2102.05912},
      year={2021},
}
@article{nguyen2021improving,
      title={Improving Mini-batch Optimal Transport via Partial Transportation}, 
      author={Khai Nguyen and Dang Nguyen and Tung Pham and Nhat Ho},
      journal={arXiv preprint arXiv:2108.09645},
      year={2021},
}

This implementation is made by Khai Nguyen and Dang Nguyen. README is on updating process.

Requirement

  • python 3.6
  • pytorch 1.7.1
  • torchvision
  • numpy
  • tqdm
  • geomloss
  • POT
  • matplotlib
  • cvxpy

What is included?

The scalable implementation of the batch of mini-batches scheme and the conventional averaging scheme of mini-batch transportation types: optimal transport (OT), partial optimal transport (POT), unbalanced optimal transport (UOT), sliced optimal transport for:

  • Deep Generative Models
  • Deep Domain Adaptation
  • Approximate Bayesian Computation
  • Color Transfer
  • Gradient Flow

Deep Adaptation on digits datasets (DeepDA/digits)

Code organization

cfg.py : this file contains arguments for training.

methods.py : this file implements the training process of the deep DA.

models.py : this file contains the architecture of the genertor and the classifier.

train_digits.py: running file for deep DA.

utils.py : this file contains implementation of utility functions.

Terminologies

--method : type of mini-batch deep DA method (jdot, jumbot, jpmbot)

--source_ds : source dataset

--target_ds : target dataset

--epsilon : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--eta1 : weight of embedding loss

--eta2 : weight of transportation loss

--k : number of mini-batches

--mbsize : mini-batch size

--n_epochs : number of running epochs

--test_interval : interval of two continuous test phase

--lr : initial learning rate

--data_dir : path to dataset

--reg : OT regularization coefficient for Sinkhorn algorithm

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches

Change the number of mini-batches $k$

bash sh/exp_mOT_change_k.sh
bash sh/exp_BoMbOT_change_k.sh

Change the mini-batch size $m$

bash sh/exp_mOT_change_m.sh
bash sh/exp_BoMbOT_change_m.sh

Deep Adaptation on Office-Home and VisDA datasets (DeepDA/office)

Code organization

data_list.py : this file contains functions to create dataset.

evaluate.py : this file is used to evaluate model trained on VisDA dataset.

lr_schedule.py : this file implements the learning rate scheduler.

network.py : this file contains the architecture of the genertor and the classifier.

pre_process.py : this file implements preprocessing techniques.

train.py : this file implements the training process for both datasets.

Terminologies

--net : architecture type of the generator

--dset : name of the dataset

--test_interval : interval of two continuous test phase

--s_dset_path : path to source dataset

--stratify_source : use stratify sampling

--s_dset_path : path to target dataset

--batch_size : training batch size

--stop_step : number of iterations

--ot_type : type of OT loss (balanced, unbalanced, partial)

--eta1 : weight of embedding loss ($\alpha$ in equation 10)

--eta2 : weight of transportation loss ($\lambda_t$ in equation 10)

--epsilon : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches

Train on Office-Home

bash sh/train_home.sh

Train on VisDA

bash sh/train_visda.sh

Deep Generative model (DeepGM)

Code organization

Celeba_generator.py, Cifar_generator.py : these files contain the architecture of the generator on CelebA and CIFAR10 datasets, and include some self-function to compute losses of corresponding baselines.

experiments.py : this file contains some functions for generating images.

fid_score.py: this file is used to compute the FID score.

gen_images.py: read saved models to produce 10000 images to calculate FID.

inception.py: this file contains the architecture of Inception Net V3.

main_celeba.py, main_cifar.py : running files on the corresponding datasets.

utils.py : this file contains implementation of utility functions.

Terminologies

--method : type of OT loss (OT, UOT, POT, sliced)

--reg : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--k : number of mini-batches

--m : mini-batch size

--epochs : number of epochs at k = 1. The actual running epochs is calculated by multiplying this value by the value of k.

--lr : initial learning rate

--latent-size : latent size of the generator

--datadir : path to dataset

--L : number of projections when using slicing approach

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches

Train on CIFAR10

CUDA_VISIBLE_DEVICES=0 python main_cifar.py --method POT --reg 0 --tau 1 \
    --mass 0.7 --k 2 --m 100 --epochs 100 --lr 5e-4 --latent-size 32 --datadir ./data

Train on CELEBA

CUDA_VISIBLE_DEVICES=0 python main_celeba.py --method POT --reg 0 --tau 1 \
    --mass 0.7 --k 2 --m 200 --epochs 100 --lr 5e-4 --latent-size 32 --datadir ./data

Gradient Flow (GradientFlow)

python main.py

Color Transfer (Color Transfer)

python main.py  --m=100 --T=10000 --source images/s1.bmp --target images/t1.bmp --cluster

Terminologies

--k : number of mini-batches

--m : the size of mini-batches

--T : the number of steps

--cluster: K mean clustering to compress images

--palette: show color palette

--source: Path to the source image

Acknowledgment

The structure of DeepDA is largely based on JUMBOT and ALDA. The structure of ABC is largely based on SlicedABC. We are very grateful for their open sources.

Owner
Khai Ba Nguyen
I am currently an AI Resident at VinAI Research, Vietnam.
Khai Ba Nguyen
Module to align code with thoughts of users and designers. Also magically handles navigation and permissions.

This readme will introduce you to Carteblanche and walk you through an example app, please refer to carteblanche-django-starter for the full example p

Eric Neuman 42 May 28, 2021
Fortnite StW Claimer for Daily Rewards, Research Points and free Llamas.

Fortnite Save the World Daily Reward, Research Points & free Llama Claimer This program allows you to claim Save the World Daily Reward, Research Poin

PRO100KatYT 27 Dec 22, 2022
Retrieve bank transactions and categorize for budgeting use

Budgeting After trying out some budgeting software, I decided to make my own. selenium_scraper Using the selenium package, this script runs an instanc

Marc 1 Nov 10, 2021
Hexa is an advanced browser.It can carry out all the functions present in a browser.

Hexa is an advanced browser.It can carry out all the functions present in a browser.It is coded in the language Python using the modules PyQt5 and sys mainly.It is gonna get developed more in the fut

1 Dec 10, 2021
List of Linux Tools I put on almost every linux / Debian host

Linux-Tools List of Linux Tools I put on almost every Linux / Debian host Installed: geany -- GUI editor/ notepad++ like chkservice -- TUI Linux ser

Stew Alexander 20 Jan 02, 2023
A collection of full-stack resources for programmers.

A collection of full-stack resources for programmers.

Charles-Axel Dein 22.3k Dec 30, 2022
Model Quantization Benchmark

MQBench Update V0.0.2 Fix academic prepare setting. More deployable prepare process. Fix setup.py. Fix deploy on SNPE. Fix convert_deploy bug. Add Qua

500 Jan 06, 2023
Gobigger Explore For Python

Gobigger-Explore 🔮 GoBigger Challenge 2021 Baseline en/中文 🤖 Introduction This is the baseline of GoBigger Multi-Agent Decision Intelligence Challeng

OpenDILab 145 Dec 22, 2022
because rico hates uuid's

terrible-uuid-lambda because rico hates uuid's sub 200ms response times! Try it out here: https://api.mathisvaneetvelde.com/uuid https://api.mathisvan

Mathis Van Eetvelde 2 Feb 15, 2022
Scientific color maps and standardization tools

Scicomap is a package that provides scientific color maps and tools to standardize your favourite color maps if you don't like the built-in ones. Scicomap currently provides sequential, bi-sequential

Thomas Bury 14 Nov 30, 2022
EloGGs 🎮 is a 1v1.LOL Trophy Boosting Program (PATCHED)

EloGGs 🎮 is an old patched 1v1.LOL boosting program I developed months ago, My team made around $1000 total off of this, but now it's been patched by the developers.

doop 1 Jul 22, 2022
Simple script with AminoLab to send ghost messages

Simple script with AminoLab to send ghost messages

Moleey 1 Nov 22, 2021
A clock purely made with python(turtle)...

Clock A clock purely made with python(turtle)... Requirements Pythone3 IDE or any other IDE Installation Clone this repository Running Open this proje

Abhyush 1 Jan 11, 2022
Application to list countries in order of travel from the United States.

Application to list countries in order of travel from the United States.

Broden Wanner 1 Nov 03, 2021
A custom advent of code I am completing

advent-of-code-custom A custom advent of code I am doing in python. The link to the problems I am solving is here: https://github.com/seldoncode/Adven

Rocio PV 2 Dec 11, 2021
Cross-platform config and manager for click console utilities.

climan Help the project financially: Donate: https://smartlegion.github.io/donate/ Yandex Money: https://yoomoney.ru/to/4100115206129186 PayPal: https

3 Aug 31, 2021
An extension module to make reaction based menus with disnake

disnake-ext-menus An experimental extension menu that makes working with reaction menus a bit easier. Installing python -m pip install -U disnake-ext-

1 Nov 25, 2021
A program that makes all 47 textures of Optifine CTM only using 2 textures

A program that makes all 47 textures of Optifine CTM only using 2 textures

1 Jan 22, 2022
An extended version of the hotkeys demo code using action classes

An extended version of the hotkeys application using action classes. In adafruit's Hotkeys code, a macro is using a series of integers, assumed to be

Neradoc 5 May 01, 2022
Tools, guides, and resources for blockchain analysts to interface with data on the Ergo platform.

Ergo Intelligence Objective Provide a suite of easy-to-use toolkits, guides, and resources for blockchain analysts and data scientists to quickly unde

Chris 5 Mar 15, 2022