A U-Net combined with a variational auto-encoder that is able to learn conditional distributions over semantic segmentations.

Overview

Probabilistic U-Net

+ **Update**
+ An improved Model (the Hierarchical Probabilistic U-Net) + LIDC crops is now available. See below.

Re-implementation of the model described in `A Probabilistic U-Net for Segmentation of Ambiguous Images' (paper @ NeurIPS 2018).

This was also a spotlight presentation at NeurIPS and a short video on the paper of similar content can be found here (4min).

The architecture of the Probabilistic U-Net is depicted below: subfigure a) shows sampling and b) the training setup:

Below see samples conditioned on held-out validation set images from the (stochastic) CityScapes data set:

Setup package in virtual environment

git clone https://github.com/SimonKohl/probabilistic_unet.git .
cd prob_unet/
virtualenv -p python3 venv
source venv/bin/activate
pip3 install -e .

Install batch-generators for data augmentation

cd ..
git clone https://github.com/MIC-DKFZ/batchgenerators
cd batchgenerators
pip3 install nilearn scikit-image nibabel
pip3 install -e .
cd prob_unet

Download & preprocess the Cityscapes dataset

  1. Create a login account on the Cityscapes website: https://www.cityscapes-dataset.com/
  2. Once you've logged in, download the train, val and test annotations and images:
  3. unzip the data (unzip _trainvaltest.zip) and adjust raw_data_dir (full path to unzipped files) and out_dir (full path to desired output directory) in preprocessing_config.py
  4. bilinearly rescale the data to a resolution of 256 x 512 and save as numpy arrays by running
cd cityscapes
python3 preprocessing.py
cd ..

Training

[skip to evaluation in case you only want to use the pretrained model.]
modify data_dir and exp_dir in scripts/prob_unet_config.py then:

cd training
python3 train_prob_unet.py --config prob_unet_config.py

Evaluation

Load your own trained model or use a pretrained model. A set of pretrained weights can be downloaded from zenodo.org (187MB). After down-loading, unpack the file via tar -xvzf pretrained_weights.tar.gz, e.g. in /model. In either case (using your own or the pretrained model), modify the data_dir and exp_dir in evaluation/cityscapes_eval_config.py to match you paths.

then first write samples (defaults to 16 segmentation samples for each of the 500 validation images):

cd ../evaluation
python3 eval_cityscapes.py --write_samples

followed by their evaluation (which is multi-threaded and thus reasonably fast):

python3 eval_cityscapes.py --eval_samples

The evaluation produces a dictionary holding the results. These can be visualized by launching an ipython notbook:

jupyter notebook evaluation_plots.ipynb

The following results are obtained from the pretrained model using above notebook:

Tests

The evaluation metrics are under test-coverage. Run the tests as follows:

cd ../tests/evaluation
python3 -m pytest eval_tests.py

Deviations from original work

The code found in this repository was not used in the original paper and slight modifications apply:

  • training on a single gpu (Titan Xp) instead of distributed training, which is not supported in this implementation
  • average-pooling rather than bilinear interpolation is used for down-sampling operations in the model
  • the number of conv kernels is kept constant after the 3rd scale as opposed to strictly doubling it after each scale (for reduction of memory footprint)
  • HeNormal weight initialization worked better than a orthogonal weight initialization

How to cite this code

Please cite the original publication:

@article{kohl2018probabilistic,
  title={A Probabilistic U-Net for Segmentation of Ambiguous Images},
  author={Kohl, Simon AA and Romera-Paredes, Bernardino and Meyer, Clemens and De Fauw, Jeffrey and Ledsam, Joseph R and Maier-Hein, Klaus H and Eslami, SM and Rezende, Danilo Jimenez and Ronneberger, Olaf},
  journal={arXiv preprint arXiv:1806.05034},
  year={2018}
}

License

The code is published under the Apache License Version 2.0.

Update: The Hierarchical Probabilistic U-Net + LIDC crops

We published an improved model, the Hierarchical Probabilistic U-Net at the Medical Imaging meets Neurips Workshop 2019.

The paper is available from arXiv under A Hierarchical Probabilistic U-Net for Modeling Multi-Scale Ambiguities, May 2019.

The model code is freely available from DeepMind's github repo, see here: code link.

The LIDC data can be downloaded as pngs, cropped to size 180 x 180 from Google Cloud Storage, see here: data link.

A pretrained model can be readily applied to the data using the following Google Colab: Open In Colab.

Owner
Simon Kohl
Simon Kohl
DANet for Tabular data classification/ regression.

Deep Abstract Networks A PyTorch code implemented for the submission DANets: Deep Abstract Networks for Tabular Data Classification and Regression. Do

Ronnie Rocket 55 Sep 14, 2022
Few-NERD: Not Only a Few-shot NER Dataset

Few-NERD: Not Only a Few-shot NER Dataset This is the source code of the ACL-IJCNLP 2021 paper: Few-NERD: A Few-shot Named Entity Recognition Dataset.

THUNLP 319 Dec 30, 2022
Repository features UNet inspired architecture used for segmenting lungs on chest X-Ray images

Lung Segmentation (2D) Repository features UNet inspired architecture used for segmenting lungs on chest X-Ray images. Demo See the application of the

163 Sep 21, 2022
Tensorflow implementation of our method: "Triangle Graph Interest Network for Click-through Rate Prediction".

TGIN Tensorflow implementation of our method: "Triangle Graph Interest Network for Click-through Rate Prediction". Files in the folder dataset/ electr

Alibaba 21 Dec 21, 2022
Projects of Andfun Yangon

AndFunYangon Projects of Andfun Yangon First Commit We can use gsearch.py to sea

Htin Aung Lu 1 Dec 28, 2021
Binary classification for arrythmia detection with ECG datasets.

HEART DISEASE AI DATATHON 2021 [Eng] / [Kor] #English This is an AI diagnosis modeling contest that uses the heart disease echocardiography and electr

HY_Kim 3 Jul 14, 2022
Full Transformer Framework for Robust Point Cloud Registration with Deep Information Interaction

Full Transformer Framework for Robust Point Cloud Registration with Deep Information Interaction. arxiv This repository contains python scripts for tr

12 Dec 12, 2022
Pytorch implementation of Integrating Tree Path in Transformer for Code Representation

This is an official Pytorch implementation of the approaches proposed in: Han Peng, Ge Li, Wenhan Wang, Yunfei Zhao, Zhi Jin “Integrating Tree Path in

Han Peng 16 Dec 23, 2022
Benchmark library for high-dimensional HPO of black-box models based on Weighted Lasso regression

LassoBench LassoBench is a library for high-dimensional hyperparameter optimization benchmarks based on Weighted Lasso regression. Note: LassoBench is

Kenan Šehić 5 Mar 15, 2022
[PNAS2021] The neural architecture of language: Integrative modeling converges on predictive processing

The neural architecture of language: Integrative modeling converges on predictive processing Code accompanying the paper The neural architecture of la

Martin Schrimpf 36 Dec 01, 2022
Space Time Recurrent Memory Network - Pytorch

Space Time Recurrent Memory Network - Pytorch (wip) Implementation of Space Time Recurrent Memory Network, recurrent network competitive with attentio

Phil Wang 50 Nov 07, 2021
Code for "AutoMTL: A Programming Framework for Automated Multi-Task Learning"

AutoMTL: A Programming Framework for Automated Multi-Task Learning This is the website for our paper "AutoMTL: A Programming Framework for Automated M

Ivy Zhang 40 Dec 04, 2022
Refactoring dalle-pytorch and taming-transformers for TPU VM

Text-to-Image Translation (DALL-E) for TPU in Pytorch Refactoring Taming Transformers and DALLE-pytorch for TPU VM with Pytorch Lightning Requirements

Kim, Taehoon 61 Nov 07, 2022
Leaf: Multiple-Choice Question Generation

Leaf: Multiple-Choice Question Generation Easy to use and understand multiple-choice question generation algorithm using T5 Transformers. The applicat

Kristiyan Vachev 62 Dec 20, 2022
ML-based medical imaging using Azure

Disclaimer This code is provided for research and development use only. This code is not intended for use in clinical decision-making or for any other

Microsoft Azure 68 Dec 23, 2022
NeurIPS 2021 paper 'Representation Learning on Spatial Networks' code

Representation Learning on Spatial Networks This repository is the official implementation of Representation Learning on Spatial Networks. Training Ex

13 Dec 29, 2022
Deep Learning Based Fasion Recommendation System for Ecommerce

Project Name: Fasion Recommendation System for Ecommerce A Deep learning based streamlit web app which can recommened you various types of fasion prod

BAPPY AHMED 13 Dec 13, 2022
Official implementation for “Unsupervised Low-Light Image Enhancement via Histogram Equalization Prior”

HEP Unsupervised Low-Light Image Enhancement via Histogram Equalization Prior Implementation Python3 PyTorch=1.0 NVIDIA GPU+CUDA Training process The

FengZhang 34 Dec 04, 2022
MohammadReza Sharifi 27 Dec 13, 2022
Regularizing Nighttime Weirdness: Efficient Self-supervised Monocular Depth Estimation in the Dark (ICCV 2021)

Regularizing Nighttime Weirdness: Efficient Self-supervised Monocular Depth Estimation in the Dark (ICCV 2021) Kun Wang, Zhenyu Zhang, Zhiqiang Yan, X

kunwang 66 Nov 24, 2022