A U-Net combined with a variational auto-encoder that is able to learn conditional distributions over semantic segmentations.

Overview

Probabilistic U-Net

+ **Update**
+ An improved Model (the Hierarchical Probabilistic U-Net) + LIDC crops is now available. See below.

Re-implementation of the model described in `A Probabilistic U-Net for Segmentation of Ambiguous Images' (paper @ NeurIPS 2018).

This was also a spotlight presentation at NeurIPS and a short video on the paper of similar content can be found here (4min).

The architecture of the Probabilistic U-Net is depicted below: subfigure a) shows sampling and b) the training setup:

Below see samples conditioned on held-out validation set images from the (stochastic) CityScapes data set:

Setup package in virtual environment

git clone https://github.com/SimonKohl/probabilistic_unet.git .
cd prob_unet/
virtualenv -p python3 venv
source venv/bin/activate
pip3 install -e .

Install batch-generators for data augmentation

cd ..
git clone https://github.com/MIC-DKFZ/batchgenerators
cd batchgenerators
pip3 install nilearn scikit-image nibabel
pip3 install -e .
cd prob_unet

Download & preprocess the Cityscapes dataset

  1. Create a login account on the Cityscapes website: https://www.cityscapes-dataset.com/
  2. Once you've logged in, download the train, val and test annotations and images:
  3. unzip the data (unzip _trainvaltest.zip) and adjust raw_data_dir (full path to unzipped files) and out_dir (full path to desired output directory) in preprocessing_config.py
  4. bilinearly rescale the data to a resolution of 256 x 512 and save as numpy arrays by running
cd cityscapes
python3 preprocessing.py
cd ..

Training

[skip to evaluation in case you only want to use the pretrained model.]
modify data_dir and exp_dir in scripts/prob_unet_config.py then:

cd training
python3 train_prob_unet.py --config prob_unet_config.py

Evaluation

Load your own trained model or use a pretrained model. A set of pretrained weights can be downloaded from zenodo.org (187MB). After down-loading, unpack the file via tar -xvzf pretrained_weights.tar.gz, e.g. in /model. In either case (using your own or the pretrained model), modify the data_dir and exp_dir in evaluation/cityscapes_eval_config.py to match you paths.

then first write samples (defaults to 16 segmentation samples for each of the 500 validation images):

cd ../evaluation
python3 eval_cityscapes.py --write_samples

followed by their evaluation (which is multi-threaded and thus reasonably fast):

python3 eval_cityscapes.py --eval_samples

The evaluation produces a dictionary holding the results. These can be visualized by launching an ipython notbook:

jupyter notebook evaluation_plots.ipynb

The following results are obtained from the pretrained model using above notebook:

Tests

The evaluation metrics are under test-coverage. Run the tests as follows:

cd ../tests/evaluation
python3 -m pytest eval_tests.py

Deviations from original work

The code found in this repository was not used in the original paper and slight modifications apply:

  • training on a single gpu (Titan Xp) instead of distributed training, which is not supported in this implementation
  • average-pooling rather than bilinear interpolation is used for down-sampling operations in the model
  • the number of conv kernels is kept constant after the 3rd scale as opposed to strictly doubling it after each scale (for reduction of memory footprint)
  • HeNormal weight initialization worked better than a orthogonal weight initialization

How to cite this code

Please cite the original publication:

@article{kohl2018probabilistic,
  title={A Probabilistic U-Net for Segmentation of Ambiguous Images},
  author={Kohl, Simon AA and Romera-Paredes, Bernardino and Meyer, Clemens and De Fauw, Jeffrey and Ledsam, Joseph R and Maier-Hein, Klaus H and Eslami, SM and Rezende, Danilo Jimenez and Ronneberger, Olaf},
  journal={arXiv preprint arXiv:1806.05034},
  year={2018}
}

License

The code is published under the Apache License Version 2.0.

Update: The Hierarchical Probabilistic U-Net + LIDC crops

We published an improved model, the Hierarchical Probabilistic U-Net at the Medical Imaging meets Neurips Workshop 2019.

The paper is available from arXiv under A Hierarchical Probabilistic U-Net for Modeling Multi-Scale Ambiguities, May 2019.

The model code is freely available from DeepMind's github repo, see here: code link.

The LIDC data can be downloaded as pngs, cropped to size 180 x 180 from Google Cloud Storage, see here: data link.

A pretrained model can be readily applied to the data using the following Google Colab: Open In Colab.

Owner
Simon Kohl
Simon Kohl
This is the official implementation of TrivialAugment and a mini-library for the application of multiple image augmentation strategies including RandAugment and TrivialAugment.

Trivial Augment This is the official implementation of TrivialAugment (https://arxiv.org/abs/2103.10158), as was used for the paper. TrivialAugment is

AutoML-Freiburg-Hannover 94 Dec 30, 2022
This repository contains all source code, pre-trained models related to the paper "An Empirical Study on GANs with Margin Cosine Loss and Relativistic Discriminator"

An Empirical Study on GANs with Margin Cosine Loss and Relativistic Discriminator This is a Pytorch implementation for the paper "An Empirical Study o

Cuong Nguyen 3 Nov 15, 2021
The Turing Change Point Detection Benchmark: An Extensive Benchmark Evaluation of Change Point Detection Algorithms on real-world data

Turing Change Point Detection Benchmark Welcome to the repository for the Turing Change Point Detection Benchmark, a benchmark evaluation of change po

The Alan Turing Institute 85 Dec 28, 2022
Deep motion transfer

animation-with-keypoint-mask Paper The right most square is the final result. Softmax mask (circles): \ Heatmap mask: \ conda env create -f environmen

9 Nov 01, 2022
基于Paddlepaddle复现yolov5,支持PaddleDetection接口

PaddleDetection yolov5 https://github.com/Sharpiless/PaddleDetection-Yolov5 简介 PaddleDetection飞桨目标检测开发套件,旨在帮助开发者更快更好地完成检测模型的组建、训练、优化及部署等全开发流程。 PaddleD

36 Jan 07, 2023
Turning pixels into virtual points for multimodal 3D object detection.

Multimodal Virtual Point 3D Detection Turning pixels into virtual points for multimodal 3D object detection. Multimodal Virtual Point 3D Detection, Ti

Tianwei Yin 204 Jan 08, 2023
A basic implementation of Layer-wise Relevance Propagation (LRP) in PyTorch.

Layer-wise Relevance Propagation (LRP) in PyTorch Basic unsupervised implementation of Layer-wise Relevance Propagation (Bach et al., Montavon et al.)

Kai Fabi 28 Dec 26, 2022
A project studying the influence of communication in multi-objective normal-form games

Communication in Multi-Objective Normal-Form Games This repo consists of five different types of agents that we have used in our study of communicatio

Willem Röpke 0 Dec 17, 2021
GNN4Traffic - This is the repository for the collection of Graph Neural Network for Traffic Forecasting

GNN4Traffic - This is the repository for the collection of Graph Neural Network for Traffic Forecasting

564 Jan 02, 2023
Self-Supervised Learning of Event-based Optical Flow with Spiking Neural Networks

Self-Supervised Learning of Event-based Optical Flow with Spiking Neural Networks Work accepted at NeurIPS'21 [paper, video]. If you use this code in

TU Delft 43 Dec 07, 2022
PyTorch implementation of "MLP-Mixer: An all-MLP Architecture for Vision" Tolstikhin et al. (2021)

mlp-mixer-pytorch PyTorch implementation of "MLP-Mixer: An all-MLP Architecture for Vision" Tolstikhin et al. (2021) Usage import torch from mlp_mixer

isaac 27 Jul 09, 2022
Discord Multi Tool that focuses on design and easy usage

Multi-Tool-v1.0 Discord Multi Tool that focuses on design and easy usage Delete webhook Block all friends Spam webhook Modify webhook Webhook info Tok

Lodi#0001 24 May 23, 2022
CO-PILOT: COllaborative Planning and reInforcement Learning On sub-Task curriculum

CO-PILOT CO-PILOT: COllaborative Planning and reInforcement Learning On sub-Task curriculum, NeurIPS 2021, Shuang Ao, Tianyi Zhou, Guodong Long, Qingh

Shuang Ao 1 Feb 18, 2022
This repo contains the code and data used in the paper "Wizard of Search Engine: Access to Information Through Conversations with Search Engines"

Wizard of Search Engine: Access to Information Through Conversations with Search Engines by Pengjie Ren, Zhongkun Liu, Xiaomeng Song, Hongtao Tian, Zh

19 Oct 27, 2022
PyTorch implementation of PSPNet segmentation network

pspnet-pytorch PyTorch implementation of PSPNet segmentation network Original paper Pyramid Scene Parsing Network Details This is a slightly different

Roman Trusov 532 Dec 29, 2022
A wrapper around SageMaker ML Lineage Tracking extending ML Lineage to end-to-end ML lifecycles, including additional capabilities around Feature Store groups, queries, and other relevant artifacts.

ML Lineage Helper This library is a wrapper around the SageMaker SDK to support ease of lineage tracking across the ML lifecycle. Lineage artifacts in

AWS Samples 12 Nov 01, 2022
details on efforts to dump the Watermelon Games Paprium cart

Reminder, if you like these repos, fork them so they don't disappear https://github.com/ArcadeHustle/WatermelonPapriumDump/fork Big thanks to Fonzie f

Hustle Arcade 29 Dec 11, 2022
Tweesent-back - Tweesent backend uses fastAPI as the web framework

TweeSent Backend Tweesent backend. This repo uses fastAPI as the web framework.

0 Mar 26, 2022
[CVPR 2021] Unsupervised 3D Shape Completion through GAN Inversion

ShapeInversion Paper Junzhe Zhang, Xinyi Chen, Zhongang Cai, Liang Pan, Haiyu Zhao, Shuai Yi, Chai Kiat Yeo, Bo Dai, Chen Change Loy "Unsupervised 3D

100 Dec 22, 2022
Chainer Implementation of Semantic Segmentation using Adversarial Networks

Semantic Segmentation using Adversarial Networks Requirements Chainer (1.23.0) Differences Use of FCN-VGG16 instead of Dilated8 as Segmentor. Caution

Taiki Oyama 99 Jun 28, 2022