Official implementation of UTNet: A Hybrid Transformer Architecture for Medical Image Segmentation

UTNet (Accepted at MICCAI 2021)

Introduction

Transformer architecture has emerged to be successful in a number of natural language processing tasks. However, its applications to medical vision remain largely unexplored. In this study, we present UTNet, a simple yet powerful hybrid Transformer architecture that integrates self-attention into a convolutional neural network for enhancing medical image segmentation. UTNet applies self-attention modules in both encoder and decoder for capturing long-range dependency at different scales with minimal overhead. To this end, we propose an efficient self-attention mechanism along with relative position encoding that reduces the complexity of self-attention operation significantly from O(n^2) to approximate O(n). A new self-attention decoder is also proposed to recover fine-grained details from the skipped connections in the encoder. Our approach addresses the dilemma that Transformer requires huge amounts of data to learn vision inductive bias. Our hybrid layer design allows the initialization of Transformer into convolutional networks without a need of pre-training. We have evaluated UTNet on the multi-label, multi-vendor cardiac magnetic resonance imaging cohort. UTNet demonstrates superior segmentation performance and robustness against the state-of-the-art approaches, holding the promise to generalize well on other medical image segmentations.
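
The complexity reduction described above can be illustrated with a minimal NumPy sketch. It assumes that keys and values are computed from a feature map that has been average-pooled to a fixed reduce_size x reduce_size grid, and it omits relative position encoding and multiple heads. The function names and weight matrices are hypothetical; this is not the repository's implementation.

```python
import numpy as np


def avg_pool_to(x, out_size):
    """Average-pool an (H, W, C) map down to (out_size, out_size, C)."""
    h, w, c = x.shape
    assert h % out_size == 0 and w % out_size == 0, "H and W must be multiples of out_size"
    x = x.reshape(out_size, h // out_size, out_size, w // out_size, c)
    return x.mean(axis=(1, 3))


def efficient_attention(feat, wq, wk, wv, reduce_size=8):
    """Single-head attention with full-resolution queries and pooled keys/values.

    Assumption: average pooling stands in for the downsampling step.
    """
    h, w, _ = feat.shape
    q = (feat @ wq).reshape(h * w, -1)                # (n, d), n = H * W
    low = avg_pool_to(feat, reduce_size)              # (r, r, C)
    k = (low @ wk).reshape(reduce_size ** 2, -1)      # (r*r, d)
    v = (low @ wv).reshape(reduce_size ** 2, -1)      # (r*r, d)
    scores = q @ k.T / np.sqrt(q.shape[1])            # (n, r*r) instead of (n, n)
    scores -= scores.max(axis=1, keepdims=True)       # numerically stable softmax
    attn = np.exp(scores)
    attn /= attn.sum(axis=1, keepdims=True)
    return (attn @ v).reshape(h, w, -1), attn


rng = np.random.default_rng(0)
feat = rng.standard_normal((32, 32, 16))
wq, wk, wv = (rng.standard_normal((16, 8)) for _ in range(3))
out, attn = efficient_attention(feat, wq, wk, wv, reduce_size=8)
print(out.shape, attn.shape)
```

Because the number of pooled positions is fixed, the attention map grows linearly with the number of pixels n rather than quadratically.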

Supported models

UTNet

TransUNet

ResNet50-UTNet

ResNet50-UNet

SwinUNet

To be continued ...

Getting Started

Currently, only the M&Ms dataset is supported.

Prerequisites

Python >= 3.6
pytorch = 1.8.1
SimpleITK = 2.0.2
numpy = 1.19.5
einops = 0.3.2

Preprocess

Resample all data to a spacing of 1.2 x 1.2 mm in the x-y plane. The spacing along the z-axis is left unchanged, as UTNet is a 2D network. Then put all the data into 'dataset/'.
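
A minimal sketch of this in-plane resampling with SimpleITK is shown below. The helper names and the is_label switch are hypothetical, and the choice of linear interpolation for images and nearest-neighbour for label maps is an assumption, not something this repo prescribes.

```python
def resampled_geometry(size, spacing, target_xy=1.2):
    """Return (new_size, new_spacing) for in-plane resampling; z is untouched."""
    new_spacing = (target_xy, target_xy, spacing[2])
    new_size = (
        int(round(size[0] * spacing[0] / target_xy)),
        int(round(size[1] * spacing[1] / target_xy)),
        size[2],
    )
    return new_size, new_spacing


def resample_xy(image, target_xy=1.2, is_label=False):
    """Resample a 3D SimpleITK image to target_xy mm in x-y (hypothetical helper)."""
    import SimpleITK as sitk  # imported lazily so the geometry helper runs without it

    new_size, new_spacing = resampled_geometry(
        image.GetSize(), image.GetSpacing(), target_xy
    )
    # Assumption: nearest neighbour for label maps, linear for intensity images.
    interpolator = sitk.sitkNearestNeighbor if is_label else sitk.sitkLinear
    return sitk.Resample(
        image, new_size, sitk.Transform(), interpolator,
        image.GetOrigin(), new_spacing, image.GetDirection(),
        0, image.GetPixelID(),
    )


# A 256 x 256 x 10 volume at 1.5 x 1.5 x 9.6 mm keeps its 10 slices.
print(resampled_geometry((256, 256, 10), (1.5, 1.5, 9.6)))
```

Keeping the slice count and z-spacing fixed means each 2D slice is rescaled independently of its neighbours.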

Training

The M&M dataset provides data from 4 vendors: vendors A and B are provided for training, while vendors A, B, C and D are used for testing. '--domain' controls which vendors are used for training: '--domain A' uses vendor A only, '--domain B' uses vendor B only, and '--domain AB' uses both vendors A and B. For testing, all 4 vendors are used.

UTNet

For the default UTNet setting, train with:

python train_deep.py -m UTNet -u EXP_NAME --data_path YOUR_OWN_PATH --reduce_size 8 --block_list 1234 --num_blocks 1,1,1,1 --domain AB --gpu 0 --aux_loss

Alternatively, use '-m UTNet_encoder' to apply transformer blocks in the encoder only. This setting is more stable than the default setting in some cases.

To optimize UTNet in your own task, there are several hyperparameters to tune:

'--block_list': specifies the resolutions at which transformer blocks are applied. Each number is a number of downsamplings, e.g. 3,4 means transformer blocks are applied to the features after 3 and 4 downsamplings. Applying transformer blocks to higher-resolution feature maps introduces much more computation.

'--num_blocks': specifies the number of transformer blocks applied at each level, e.g. block_list='3,4', num_blocks=2,4 means 2 transformer blocks are applied at the 3-times downsampling level and 4 transformer blocks at the 4-times downsampling level.

'--reduce_size': specifies the downsampled size used for efficient attention. In our experiments, reduce_size 8 and 16 make little difference, but 16 introduces more computation, so we choose 8 as our default setting. 16 might perform better in other applications.

'--aux_loss': applies deep supervision during training. It introduces some computation overhead but gives slightly better performance.
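
How '--block_list' and '--num_blocks' pair up can be sketched as follows. This is an illustrative helper with a hypothetical name, not the parsing code in train_deep.py; it assumes each digit of block_list names one level and accepts both the '1234' and the '3,4' spelling used in this README.

```python
def pair_blocks(block_list, num_blocks):
    """Map each downsampling level to its number of transformer blocks."""
    levels = [int(ch) for ch in block_list if ch.isdigit()]
    counts = [int(tok) for tok in num_blocks.split(",") if tok.strip()]
    if len(levels) != len(counts):
        raise ValueError(
            "block_list has %d levels but num_blocks has %d entries"
            % (len(levels), len(counts))
        )
    return dict(zip(levels, counts))


print(pair_blocks("1234", "1,1,4,8"))  # level -> number of blocks
print(pair_blocks("3,4", "2,4"))
```

Checking the two lengths up front catches a mismatched pair of flags before a long training run starts.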

Here are some recommended parameter settings:

--block_list 1234 --num_blocks 1,1,1,1

Our default and most efficient setting. Suitable for tasks with limited training data in which most errors occur at the boundary of the ROI, where high-resolution information is important.

--block_list 1234 --num_blocks 1,1,4,8

Similar to the previous one. The model capacity is larger because more transformer blocks are included, but it needs a larger dataset for training.

--block_list 234 --num_blocks 2,4,8

Suitable for tasks that have complex contexts and in which errors occur inside the ROI. More transformer blocks can help learn higher-level relationships.

Feel free to try other combinations of hyperparameters, such as base_chan, reduce_size and num_blocks at each level, to trade off capacity against efficiency for your own tasks and datasets.
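
To compare the recommended settings side by side, the commands can be generated from a small table. The sketch below uses only flags that appear in this README. The preset names and the build_command helper are hypothetical, and it assumes the remaining flags stay as in the default command.

```python
import shlex

# Flag values copied from the recommended settings; the preset names are made up here.
PRESETS = {
    "efficient": {"block_list": "1234", "num_blocks": "1,1,1,1"},
    "large": {"block_list": "1234", "num_blocks": "1,1,4,8"},
    "context": {"block_list": "234", "num_blocks": "2,4,8"},
}


def build_command(preset, exp_name, data_path, domain="AB",
                  reduce_size=8, gpu=0, aux_loss=True):
    """Return the training command for one preset as an argument list."""
    cfg = PRESETS[preset]
    cmd = [
        "python", "train_deep.py", "-m", "UTNet", "-u", exp_name,
        "--data_path", data_path, "--reduce_size", str(reduce_size),
        "--block_list", cfg["block_list"], "--num_blocks", cfg["num_blocks"],
        "--domain", domain, "--gpu", str(gpu),
    ]
    if aux_loss:
        cmd.append("--aux_loss")
    return cmd


for name in PRESETS:
    cmd = build_command(name, "EXP_NAME", "YOUR_OWN_PATH")
    print(" ".join(shlex.quote(part) for part in cmd))
```

The argument list can be passed directly to subprocess.run, which avoids shell quoting issues with paths that contain spaces.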

TransUNet

We borrow code from the original TransUNet repo and fit it into our training framework. If you want to use pre-trained weights, please download them from the original repo. The configuration is not parsed from the command line, so if you want to change the configuration of TransUNet, you need to change it inside train_deep.py.

python train_deep.py -m TransUNet -u EXP_NAME --data_path YOUR_OWN_PATH --gpu 0

ResNet50-UTNet

For a fair comparison with TransUNet, we implement the efficient attention proposed in UTNet on a ResNet50 backbone, which basically appends transformer blocks to the specified levels after the ResNet blocks. ResNet50-UTNet performs slightly better than the default UTNet on the M&M dataset.

python train_deep.py -m ResNet_UTNet -u EXP_NAME --data_path YOUR_OWN_PATH --reduce_size 8 --block_list 123 --num_blocks 1,1,1 --gpu 0

Similar to UTNet, this is the most efficient setting, suitable for tasks with limited training data.

--block_list 23 --num_blocks 2,4

Suitable for tasks that have complex contexts and in which errors occur inside the ROI. More transformer blocks can help learn higher-level relationships.

ResNet50-UNet

If you don't use Transformer blocks in ResNet50-UTNet, it is effectively ResNet50-UNet. You can therefore use it as a baseline to measure the performance improvement from the Transformer blocks, for a fair comparison with TransUNet and our UTNet.

python train_deep.py -m ResNet_UTNet -u EXP_NAME --data_path YOUR_OWN_PATH --block_list ''  --gpu 0

SwinUNet

Download the pre-trained model from the original repo. Because Swin-Transformer's input size is tied to the window size and is hard to change after pre-training, we adapt our input size to 224. Without pre-training, SwinUNet's performance is very low.

python train_deep.py -m SwinUNet -u EXP_NAME --data_path YOUR_OWN_PATH --crop_size 224
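
The link between input size and window size can be checked with a short sketch. It assumes the commonly used Swin-T values of patch size 4, window size 7 and four stages with a 2x patch merging between them. These values and the helper name are assumptions, not read from this repo or its configuration.

```python
def swin_input_fits(crop_size, patch_size=4, window_size=7, num_stages=4):
    """Return True if every stage's feature map tiles evenly into windows."""
    if crop_size % patch_size:
        return False
    res = crop_size // patch_size  # resolution after patch embedding
    for stage in range(num_stages):
        if res % window_size:
            return False
        if stage < num_stages - 1:
            if res % 2:
                return False
            res //= 2  # patch merging halves the resolution
    return True


for size in (224, 256):
    print(size, swin_input_fits(size))
```

Under these assumptions a crop of 224 gives stage resolutions 56, 28, 14 and 7, each a multiple of the window size, whereas 256 gives 64 at the first stage, which is not.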

Citation

If you find this repo helpful, please kindly cite our paper. Thanks!

@inproceedings{gao2021utnet,
  title={UTNet: a hybrid transformer architecture for medical image segmentation},
  author={Gao, Yunhe and Zhou, Mu and Metaxas, Dimitris N},
  booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
  pages={61--71},
  year={2021},
  organization={Springer}
}