Official PyTorch Implementation of Hypercorrelation Squeeze for Few-Shot Segmentation, arXiv 2021

Related tags

Deep Learninghsnet
Overview

PWC PWC PWC PWC PWC PWC PWC PWC

Hypercorrelation Squeeze for Few-Shot Segmentation

This is the implementation of the paper "Hypercorrelation Squeeze for Few-Shot Segmentation" by Juhong Min, Dahyun Kang, and Minsu Cho. Implemented on Python 3.7 and Pytorch 1.5.1.

For more information, check out project [website] and the paper on [arXiv].

Requirements

  • Python 3.7
  • PyTorch 1.5.1
  • cuda 10.1
  • tensorboard 1.14

Conda environment settings:

conda create -n hsnet python=3.7
conda activate hsnet

conda install pytorch=1.5.1 torchvision cudatoolkit=10.1 -c pytorch
conda install -c conda-forge tensorflow
pip install tensorboardX

Preparing Few-Shot Segmentation Datasets

Download following datasets:

1. PASCAL-5i

Download PASCAL VOC2012 devkit (train/val data):

wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar

Download PASCAL VOC2012 SDS extended mask annotations from our [Google Drive].

2. COCO-20i

Download COCO2014 train/val images and annotations:

wget http://images.cocodataset.org/zips/train2014.zip
wget http://images.cocodataset.org/zips/val2014.zip
wget http://images.cocodataset.org/annotations/annotations_trainval2014.zip

Download COCO2014 train/val annotations from our Google Drive: [train2014.zip], [val2014.zip]. (and locate both train2014/ and val2014/ under annotations/ directory).

3. FSS-1000

Download FSS-1000 images and annotations from our [Google Drive].

Create a directory '../Datasets_HSN' for the above three few-shot segmentation datasets and appropriately place each dataset to have following directory structure:

../                         # parent directory
├── ./                      # current (project) directory
│   ├── common/             # (dir.) helper functions
│   ├── data/               # (dir.) dataloaders and splits for each FSSS dataset
│   ├── model/              # (dir.) implementation of Hypercorrelation Squeeze Network model 
│   ├── README.md           # intstruction for reproduction
│   ├── train.py            # code for training HSNet
│   └── test.py             # code for testing HSNet
└── Datasets_HSN/
    ├── VOC2012/            # PASCAL VOC2012 devkit
    │   ├── Annotations/
    │   ├── ImageSets/
    │   ├── ...
    │   └── SegmentationClassAug/
    ├── COCO2014/           
    │   ├── annotations/
    │   │   ├── train2014/  # (dir.) training masks (from Google Drive) 
    │   │   ├── val2014/    # (dir.) validation masks (from Google Drive)
    │   │   └── ..some json files..
    │   ├── train2014/
    │   └── val2014/
    └── FSS-1000/           # (dir.) contains 1000 object classes
        ├── abacus/   
        ├── ...
        └── zucchini/

Training

1. PASCAL-5i

python train.py --backbone {vgg16, resnet50, resnet101} 
                --fold {0, 1, 2, 3} 
                --benchmark pascal
                --lr 1e-3
                --bsz 20
                --load "path_to_trained_model/best_model.pt"
                --logpath "your_experiment_name"
  • Training takes approx. 2 days until convergence (trained with four 2080 Ti GPUs).

2. COCO-20i

python train.py --backbone {resnet50, resnet101} 
                --fold {0, 1, 2, 3} 
                --benchmark coco 
                --lr 1e-3
                --bsz 40
                --load "path_to_trained_model/best_model.pt"
                --logpath "your_experiment_name"
  • Training takes approx. 1 week until convergence (trained four Titan RTX GPUs).

3. FSS-1000

python train.py --backbone {vgg16, resnet50, resnet101} 
                --benchmark fss 
                --lr 1e-3
                --bsz 20
                --load "path_to_trained_model/best_model.pt"
                --logpath "your_experiment_name"
  • Training takes approx. 3 days until convergence (trained with four 2080 Ti GPUs).

Babysitting training:

Use tensorboard to babysit training progress:

  • For each experiment, a directory that logs training progress will be automatically generated under logs/ directory.
  • From terminal, run 'tensorboard --logdir logs/' to monitor the training progress.
  • Choose the best model when the validation (mIoU) curve starts to saturate.

Testing

1. PASCAL-5i

Pretrained models with tensorboard logs are available on our [Google Drive].

python test.py --backbone {vgg16, resnet50, resnet101} 
               --fold {0, 1, 2, 3} 
               --benchmark pascal
               --nshot {1, 5} 
               --load "path_to_trained_model/best_model.pt"

2. COCO-20i

Pretrained models with tensorboard logs are available on our [Google Drive].

python test.py --backbone {resnet50, resnet101} 
               --fold {0, 1, 2, 3} 
               --benchmark coco 
               --nshot {1, 5} 
               --load "path_to_trained_model/best_model.pt"

3. FSS-1000

Pretrained models with tensorboard logs are available on our [Google Drive].

python test.py --backbone {vgg16, resnet50, resnet101} 
               --benchmark fss 
               --nshot {1, 5} 
               --load "path_to_trained_model/best_model.pt"

4. Evaluation without support feature masking on PASCAL-5i

  • To reproduce the results in Tab.1 of our main paper, COMMENT OUT line 51 in hsnet.py: support_feats = self.mask_feature(support_feats, support_mask.clone())

Pretrained models with tensorboard logs are available on our [Google Drive].

python test.py --backbone resnet101 
               --fold {0, 1, 2, 3} 
               --benchmark pascal
               --nshot {1, 5} 
               --load "path_to_trained_model/best_model.pt"

Visualization

  • To visualize mask predictions, add command line argument --visualize: (prediction results will be saved under vis/ directory)
  python test.py '...other arguments...' --visualize  

Example qualitative results (1-shot):

BibTeX

If you use this code for your research, please consider citing:

@article{min2021hypercorrelation, 
   title={Hypercorrelation Squeeze for Few-Shot Segmentation},
   author={Juhong Min and Dahyun Kang and Minsu Cho},
   journal={arXiv preprint arXiv:2104.01538},
   year={2021}
}
Owner
Juhong Min
research interest in computer vision
Juhong Min
Code and data for paper "Deep Photo Style Transfer"

deep-photo-styletransfer Code and data for paper "Deep Photo Style Transfer" Disclaimer This software is published for academic and non-commercial use

Fujun Luan 9.9k Dec 29, 2022
Taming Transformers for High-Resolution Image Synthesis

Taming Transformers for High-Resolution Image Synthesis CVPR 2021 (Oral) Taming Transformers for High-Resolution Image Synthesis Patrick Esser*, Robin

CompVis Heidelberg 3.5k Jan 03, 2023
This repository includes the code of the sequence-to-sequence model for discontinuous constituent parsing described in paper Discontinuous Grammar as a Foreign Language.

Discontinuous Grammar as a Foreign Language This repository includes the code of the sequence-to-sequence model for discontinuous constituent parsing

Daniel Fernández-González 2 Apr 07, 2022
Pure python implementation reverse-mode automatic differentiation

MiniGrad A minimal implementation of reverse-mode automatic differentiation (a.k.a. autograd / backpropagation) in pure Python. Inspired by Andrej Kar

Kenny Song 76 Sep 12, 2022
Motion and Shape Capture from Sparse Markers

MoSh++ This repository contains the official chumpy implementation of mocap body solver used for AMASS: AMASS: Archive of Motion Capture as Surface Sh

Nima Ghorbani 135 Dec 23, 2022
Snscrape-jsonl-urls-extractor - Extracts urls from jsonl produced by snscrape

snscrape-jsonl-urls-extractor extracts urls from jsonl produced by snscrape Usag

1 Feb 26, 2022
BirdCLEF 2021 - Birdcall Identification 4th place solution

BirdCLEF 2021 - Birdcall Identification 4th place solution My solution detail kaggle discussion Inference Notebook (best submission) Environment Use K

tattaka 42 Jan 02, 2023
Explainability for Vision Transformers (in PyTorch)

Explainability for Vision Transformers (in PyTorch) This repository implements methods for explainability in Vision Transformers

Jacob Gildenblat 442 Jan 04, 2023
MediaPipe Kullanarak İleri Seviye Bilgisayarla Görü

MediaPipe Kullanarak İleri Seviye Bilgisayarla Görü

Burak Bagatarhan 12 Mar 29, 2022
Script for getting information in discord

User-info.py Script for getting information in https://discord.com/ Instalação: apt-get update -y apt-get upgrade -y apt-get install git pkg install

Moleey 1 Dec 18, 2021
Point cloud processing tool library.

Point Cloud ToolBox This point cloud processing tool library can be used to process point clouds, 3d meshes, and voxels. Environment python 3.7.5 Dep

ZhangXinyun 40 Dec 09, 2022
abess: Fast Best-Subset Selection in Python and R

abess: Fast Best-Subset Selection in Python and R Overview abess (Adaptive BEst Subset Selection) library aims to solve general best subset selection,

297 Dec 21, 2022
Unofficial implementation of HiFi-GAN+ from the paper "Bandwidth Extension is All You Need" by Su, et al.

HiFi-GAN+ This project is an unoffical implementation of the HiFi-GAN+ model for audio bandwidth extension, from the paper Bandwidth Extension is All

Brent M. Spell 134 Dec 30, 2022
MoCoPnet - Deformable 3D Convolution for Video Super-Resolution

MoCoPnet: Exploring Local Motion and Contrast Priors for Infrared Small Target Super-Resolution Pytorch implementation of local motion and contrast pr

Xinyi Ying 28 Dec 15, 2022
Python module providing a framework to trace individual edges in an image using Gaussian process regression.

Edge Tracing using Gaussian Process Regression Repository storing python module which implements a framework to trace individual edges in an image usi

Jamie Burke 7 Dec 27, 2022
The official pytorch implemention of the CVPR paper "Temporal Modulation Network for Controllable Space-Time Video Super-Resolution".

This is the official PyTorch implementation of TMNet in the CVPR 2021 paper "Temporal Modulation Network for Controllable Space-Time VideoSuper-Resolu

Gang Xu 95 Oct 24, 2022
Internship Assessment Task for BaggageAI.

BaggageAI Internship Task Problem Statement: You are given two sets of images:- background and threat objects. Background images are the background x-

Arya Shah 10 Nov 14, 2022
Code repo for "FASA: Feature Augmentation and Sampling Adaptation for Long-Tailed Instance Segmentation" (ICCV 2021)

FASA: Feature Augmentation and Sampling Adaptation for Long-Tailed Instance Segmentation (ICCV 2021) This repository contains the implementation of th

Yuhang Zang 21 Dec 17, 2022
Basics of 2D and 3D Human Pose Estimation.

Human Pose Estimation 101 If you want a slightly more rigorous tutorial and understand the basics of Human Pose Estimation and how the field has evolv

Sudharshan Chandra Babu 293 Dec 14, 2022