Code for NeurIPS 2020 article "Contrastive learning of global and local features for medical image segmentation with limited annotations"

Overview

Contrastive learning of global and local features for medical image segmentation with limited annotations

The code is for the article "Contrastive learning of global and local features for medical image segmentation with limited annotations" which got accepted as an Oral presentation at NeurIPS 2020 (33rd international conference on Neural Information Processing Systems). With the proposed pre-training method using Contrastive learning, we get competitive segmentation performance with just 2 labeled training volumes compared to a benchmark that is trained with many labeled volumes.
https://arxiv.org/abs/2006.10511

Observations / Conclusions:

  1. For medical image segmentation, the proposed contrastive pre-training strategy incorporating domain knowledge present naturally across medical volumes yields better performance than baseline, other pre-training methods, semi-supervised, and data augmentation methods.
  2. Proposed local contrastive loss, an extension of global loss, provides an additional boost in performance by learning distinctive local-level representation to distinguish between neighbouring regions.
  3. The proposed pre-training strategy is complementary to semi-supervised and data augmentation methods. Combining them yields a further boost in accuracy.

Authors:
Krishna Chaitanya (email),
Ertunc Erdil,
Neerav Karani,
Ender Konukoglu.

Requirements:
Python 3.6.1,
Tensorflow 1.12.0,
rest of the requirements are mentioned in the "requirements.txt" file.

I) To clone the git repository.
git clone https://github.com/krishnabits001/domain_specific_dl.git

II) Install python, required packages and tensorflow.
Then, install python packages required using below command or the packages mentioned in the file.
pip install -r requirements.txt

To install tensorflow
pip install tensorflow-gpu=1.12.0

III) Dataset download.
To download the ACDC Cardiac dataset, check the website :
https://www.creatis.insa-lyon.fr/Challenge/acdc.

To download the Medical Decathlon Prostate dataset, check the website :
http://medicaldecathlon.com/

To download the MMWHS Cardiac dataset, check the website :
http://www.sdspeople.fudan.edu.cn/zhuangxiahai/0/mmwhs/

All the images were bias corrected using N4 algorithm with a threshold value of 0.001. For more details, refer to the "N4_bias_correction.py" file in scripts.
Image and label pairs are re-sampled (to chosen target resolution) and cropped/zero-padded to a fixed size using "create_cropped_imgs.py" file.

IV) Train the models.
Below commands are an example for ACDC dataset.
The models need to be trained sequentially as follows (check "train_model/pretrain_and_fine_tune_script.sh" script for commands)
Steps :

  1. Step 1: To pre-train the encoder with global loss by incorporating proposed domain knowledge when defining positive and negative pairs.
    cd train_model/
    python pretr_encoder_global_contrastive_loss.py --dataset=acdc --no_of_tr_imgs=tr52 --global_loss_exp_no=2 --n_parts=4 --temp_fac=0.1 --bt_size=12

  2. Step 2: After step 1, we pre-train the decoder with proposed local loss to aid segmentation task by learning distinctive local-level representations.
    python pretr_decoder_local_contrastive_loss.py --dataset=acdc --no_of_tr_imgs=tr52 --pretr_no_of_tr_imgs=tr52 --local_reg_size=1 --no_of_local_regions=13 --temp_fac=0.1 --global_loss_exp_no=2 --local_loss_exp_no=0 --no_of_decoder_blocks=3 --no_of_neg_local_regions=5 --bt_size=12

  3. Step 3: We use the pre-trained encoder and decoder weights as initialization and fine-tune to segmentation task using limited annotations.
    python ft_pretr_encoder_decoder_net_local_loss.py --dataset=acdc --pretr_no_of_tr_imgs=tr52 --local_reg_size=1 --no_of_local_regions=13 --temp_fac=0.1 --global_loss_exp_no=2 --local_loss_exp_no=0 --no_of_decoder_blocks=3 --no_of_neg_local_regions=5 --no_of_tr_imgs=tr1 --comb_tr_imgs=c1 --ver=0

To train the baseline with affine and random deformations & intensity transformations for comparison, use the below code file.
cd train_model/
python tr_baseline.py --dataset=acdc --no_of_tr_imgs=tr1 --comb_tr_imgs=c1 --ver=0

V) Config files contents.
One can modify the contents of the below 2 config files to run the required experiments.
experiment_init directory contains 2 files.
Example for ACDC dataset:

  1. init_acdc.py
    --> contains the config details like target resolution, image dimensions, data path where the dataset is stored and path to save the trained models.
  2. data_cfg_acdc.py
    --> contains an example of data config details where one can set the patient ids which they want to use as train, validation and test images.

Bibtex citation:

@article{chaitanya2020contrastive,
  title={Contrastive learning of global and local features for medical image segmentation with limited annotations},
  author={Chaitanya, Krishna and Erdil, Ertunc and Karani, Neerav and Konukoglu, Ender},
  journal={Advances in Neural Information Processing Systems},
  volume={33},
  year={2020}
}
Owner
Krishna Chaitanya
Doctoral Student, ETH Zurich
Krishna Chaitanya
TJU Deep Learning & Neural Network

Deep_Learning & Neural_Network_Lab 实验环境 Python 3.9 Anaconda3(官网下载或清华镜像都行) PyTorch 1.10.1(安装代码如下) conda install pytorch torchvision torchaudio cudatool

St3ve Lee 1 Jan 19, 2022
Exploring Simple 3D Multi-Object Tracking for Autonomous Driving (ICCV 2021)

Exploring Simple 3D Multi-Object Tracking for Autonomous Driving Chenxu Luo, Xiaodong Yang, Alan Yuille Exploring Simple 3D Multi-Object Tracking for

QCraft 141 Nov 21, 2022
PoseCamera is python based SDK for human pose estimation through RGB webcam.

PoseCamera PoseCamera is python based SDK for human pose estimation through RGB webcam. Install install posecamera package through pip pip install pos

WonderTree 7 Jul 20, 2021
A Jinja extension (compatible with Flask and other frameworks) to compile and/or compress your assets.

A Jinja extension (compatible with Flask and other frameworks) to compile and/or compress your assets.

Jayson Reis 94 Nov 21, 2022
Towards Implicit Text-Guided 3D Shape Generation (CVPR2022)

Towards Implicit Text-Guided 3D Shape Generation Towards Implicit Text-Guided 3D Shape Generation (CVPR2022) Code for the paper [Towards Implicit Text

55 Dec 16, 2022
Demonstration of the Model Training as a CI/CD System in Vertex AI

Model Training as a CI/CD System This project demonstrates the machine model training as a CI/CD system in GCP platform. You will see more detailed wo

Chansung Park 19 Dec 28, 2022
Composing methods for ML training efficiency

MosaicML Composer contains a library of methods, and ways to compose them together for more efficient ML training.

MosaicML 2.8k Jan 08, 2023
code for our ECCV-2020 paper: Self-supervised Video Representation Learning by Pace Prediction

Video_Pace This repository contains the code for the following paper: Jiangliu Wang, Jianbo Jiao and Yunhui Liu, "Self-Supervised Video Representation

Jiangliu Wang 95 Dec 14, 2022
Deep Learning as a Cloud API Service.

Deep API Deep Learning as Cloud APIs. This project provides pre-trained deep learning models as a cloud API service. A web interface is available as w

Wu Han 4 Jan 06, 2023
This is the code of "Multi-view Contrastive Graph Clustering" in NeurlPS 2021.

MCGC Description This is the code of "Multi-view Contrastive Graph Clustering" in NeurlPS 2021. Datasets Results ACM DBLP IMDB Amazon photos Amazon co

31 Nov 14, 2022
Calling Julia from Python - an experiment on data loading

Calling Julia from Python - an experiment on data loading See the slides. TLDR After reading Patrick's blog post, we decided to try to replace C++ wit

Abel Siqueira 8 Jun 07, 2022
Official implementation of "Membership Inference Attacks Against Self-supervised Speech Models"

Introduction Official implementation of "Membership Inference Attacks Against Self-supervised Speech Models". In this work, we demonstrate that existi

Wei-Cheng Tseng 7 Nov 01, 2022
Semantic Edge Detection with Diverse Deep Supervision

Semantic Edge Detection with Diverse Deep Supervision This repository contains the code for our IJCV paper: "Semantic Edge Detection with Diverse Deep

Yun Liu 12 Dec 31, 2022
Efficient Deep Learning Systems course

Efficient Deep Learning Systems This repository contains materials for the Efficient Deep Learning Systems course taught at the Faculty of Computer Sc

Max Ryabinin 173 Dec 29, 2022
Face2webtoon - Despite its importance, there are few previous works applying I2I translation to webtoon.

Despite its importance, there are few previous works applying I2I translation to webtoon. I collected dataset from naver webtoon 연애혁명 and tried to transfer human faces to webtoon domain.

이상윤 64 Oct 19, 2022
[CVPR 2022] Official code for the paper: "A Stitch in Time Saves Nine: A Train-Time Regularizing Loss for Improved Neural Network Calibration"

MDCA Calibration This is the official PyTorch implementation for the paper: "A Stitch in Time Saves Nine: A Train-Time Regularizing Loss for Improved

MDCA Calibration 21 Dec 22, 2022
Official code for 'Robust Siamese Object Tracking for Unmanned Aerial Manipulator' and offical introduction to UAMT100 benchmark

SiamSA: Robust Siamese Object Tracking for Unmanned Aerial Manipulator Demo video 📹 Our video on Youtube and bilibili demonstrates the evaluation of

Intelligent Vision for Robotics in Complex Environment 12 Dec 18, 2022
PyTorch implementation of GLOM

GLOM PyTorch implementation of GLOM, Geoffrey Hinton's new idea that integrates concepts from neural fields, top-down-bottom-up processing, and attent

Yeonwoo Sung 20 Aug 17, 2022
Source codes for Improved Few-Shot Visual Classification (CVPR 2020), Enhancing Few-Shot Image Classification with Unlabelled Examples

Source codes for Improved Few-Shot Visual Classification (CVPR 2020), Enhancing Few-Shot Image Classification with Unlabelled Examples (WACV 2022) and Beyond Simple Meta-Learning: Multi-Purpose Model

PLAI Group at UBC 42 Dec 06, 2022
Pytorch implementation for Semantic Segmentation/Scene Parsing on MIT ADE20K dataset

Semantic Segmentation on MIT ADE20K dataset in PyTorch This is a PyTorch implementation of semantic segmentation models on MIT ADE20K scene parsing da

MIT CSAIL Computer Vision 4.5k Jan 08, 2023