Deep Image Matting implementation in PyTorch

Overview

Deep Image Matting

A PyTorch implementation of the Deep Image Matting paper.

Differences

  1. "fc6" is dropped.
  2. Indices pooling.
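
As a rough illustration of item 2, assuming "indices pooling" means that each max-pooling step records where the maximum came from so that a later unpooling step can put values back at those positions: the sketch below uses plain Python lists, and the helper names `max_pool_with_indices` and `max_unpool` are hypothetical, not taken from this repository. In PyTorch the corresponding building blocks would be `torch.nn.MaxPool2d(return_indices=True)` and `torch.nn.MaxUnpool2d`.

```python
def max_pool_with_indices(x, k=2):
    """k x k max pooling with stride k over a 2-D list.

    Returns the pooled map and, for every output cell, the flat index
    (row * width + col) of the input element that was selected.
    """
    h, w = len(x), len(x[0])
    pooled, indices = [], []
    for i in range(0, h - k + 1, k):
        pooled_row, index_row = [], []
        for j in range(0, w - k + 1, k):
            # All (value, flat_index) pairs inside the current window.
            window = [(x[i + di][j + dj], (i + di) * w + (j + dj))
                      for di in range(k) for dj in range(k)]
            value, index = max(window, key=lambda pair: pair[0])
            pooled_row.append(value)
            index_row.append(index)
        pooled.append(pooled_row)
        indices.append(index_row)
    return pooled, indices


def max_unpool(pooled, indices, out_h, out_w):
    """Scatter pooled values back to their recorded positions; zeros elsewhere."""
    out = [[0] * out_w for _ in range(out_h)]
    for pooled_row, index_row in zip(pooled, indices):
        for value, index in zip(pooled_row, index_row):
            out[index // out_w][index % out_w] = value
    return out


x = [[1, 2, 0, 1],
     [3, 4, 5, 0],
     [0, 1, 2, 2],
     [9, 0, 1, 7]]
pooled, indices = max_pool_with_indices(x)
restored = max_unpool(pooled, indices, 4, 4)
print(pooled)    # [[4, 5], [9, 7]]
print(indices)   # [[5, 6], [12, 15]]
print(restored)
```

The unpooled map keeps each maximum at its original location, which preserves spatial detail that plain upsampling would blur.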

"fc6" is bulky: with over 100 million parameters, it makes the model hard to converge. I suspect this is why the model in the paper has to be trained in stages.
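
The parameter figure can be checked with a little arithmetic. The shapes below are an assumption taken from the standard VGG-16 definition (a 7 x 7 x 512 feature map feeding 4096 outputs), not from this repository's code:

```python
# Assumed fc6 shapes from the standard VGG-16 layout (not read from this repo).
in_channels = 512      # channels of the last pooled feature map
kernel_h = kernel_w = 7
out_features = 4096

weights = kernel_h * kernel_w * in_channels * out_features
biases = out_features
total = weights + biases

print("fc6 weights: {:,}".format(weights))   # fc6 weights: 102,760,448
print("fc6 total:   {:,}".format(total))     # fc6 total:   102,764,544
```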

Performance

  • Evaluated on the Composition-1k test dataset.
  • Evaluated on the whole image.
  • SAD is normalized by 1000.
  • The input image is normalized with mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225].
  • Both erosion and dilation are used to generate the trimap.
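
One way the last point could look is sketched below in plain Python. It assumes an 8-bit alpha matte (0 to 255) and a 3 x 3 structuring element; the kernel size and the helper names `erode`, `dilate` and `make_trimap` are illustrative choices, not the repository's actual settings. The fully opaque region is eroded to get confident foreground, the non-zero region is dilated to get the outer limit of the subject, and everything in between is marked unknown (128).

```python
def _window_filter(mask, k, reduce_fn):
    """Apply a k x k min or max filter to a 0/1 mask, clipping at the border."""
    h, w = len(mask), len(mask[0])
    r = k // 2
    out = [[0] * w for _ in range(h)]
    for i in range(h):
        for j in range(w):
            window = [mask[y][x]
                      for y in range(max(0, i - r), min(h, i + r + 1))
                      for x in range(max(0, j - r), min(w, j + r + 1))]
            out[i][j] = reduce_fn(window)
    return out


def erode(mask, k=3):
    """Erosion: a pixel survives only if its whole neighbourhood is set."""
    return _window_filter(mask, k, min)


def dilate(mask, k=3):
    """Dilation: a pixel is set if any neighbour is set."""
    return _window_filter(mask, k, max)


def make_trimap(alpha, k=3):
    """Return a trimap with 255 = foreground, 128 = unknown, 0 = background."""
    opaque = [[1 if a == 255 else 0 for a in row] for row in alpha]
    nonzero = [[1 if a > 0 else 0 for a in row] for row in alpha]
    sure_fg = erode(opaque, k)
    reachable = dilate(nonzero, k)
    return [[255 if fg else (128 if near else 0)
             for fg, near in zip(fg_row, near_row)]
            for fg_row, near_row in zip(sure_fg, reachable)]


# Toy matte: a 3 x 3 opaque square in the middle of a 7 x 7 image.
alpha = [[0] * 7 for _ in range(7)]
for i in range(2, 5):
    for j in range(2, 5):
        alpha[i][j] = 255

trimap = make_trimap(alpha)
print(trimap[3])   # [0, 128, 128, 255, 128, 128, 0]
```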
Models        SAD   MSE    Download
paper-stage0  59.6  0.019
paper-stage1  54.6  0.017
paper-stage3  50.4  0.014
my-stage0     66.8  0.024  Link
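
For reference, the two metrics in the table can be sketched as follows. This is a minimal version under stated assumptions: alpha values are floats in [0, 1], the error is taken over every pixel (matching "evaluated on the whole image" above), and the function name `sad_and_mse` is hypothetical. The repository's own test script may differ in details such as value range.

```python
def sad_and_mse(pred, gt):
    """Return (SAD / 1000, MSE) for two same-sized 2-D alpha mattes.

    pred, gt: lists of rows with alpha values in [0, 1] (assumed range).
    """
    diffs = [p - g
             for pred_row, gt_row in zip(pred, gt)
             for p, g in zip(pred_row, gt_row)]
    sad = sum(abs(d) for d in diffs) / 1000.0   # sum of absolute differences, scaled
    mse = sum(d * d for d in diffs) / len(diffs)  # mean squared error
    return sad, mse


pred = [[0.0, 0.5],
        [1.0, 1.0]]
gt = [[0.0, 1.0],
      [1.0, 0.5]]
sad, mse = sad_and_mse(pred, gt)
print(sad, mse)   # 0.001 0.125
```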

Dependencies

  • Python 3.5.2
  • PyTorch 1.1.0

Dataset

Adobe Deep Image Matting Dataset

Follow the instructions to contact the author for the dataset.

MSCOCO

Go to MSCOCO to download:

PASCAL VOC

Go to PASCAL VOC to download:

Usage

Data Pre-processing

Extract training images:

$ python pre_process.py

Train

$ python train.py

To visualize training progress, run this in your terminal:

$ tensorboard --logdir runs

Experimental results

The Composition-1k testing dataset

  1. Test:
$ python test.py

It prints out average SAD and MSE errors when finished.

The alphamatting.com dataset

  1. Download the evaluation datasets: Go to the Datasets page and download the evaluation datasets. Make sure you pick the low-resolution dataset.

  2. Extract evaluation images:

$ python extract.py
  3. Evaluate:
$ python eval.py

[Figure: grid of evaluation images with columns Image, Trimap1, Trimap2, Trimap3.]

Demo

Download the pre-trained Deep Image Matting model (Link), then run:

$ python demo.py

[Figure: grid of demo results with columns Image/Trimap, Output/GT, New BG/Compose.]

A small donation

If this project has helped you, a small donation is welcome.




Comments
  • the frozen model named BEST_checkpoint.tar cannot be uncompressed

    When I try to uncompress the frozen model, it shows:

    tar: This does not look like a tar archive
    tar: Skipping to next header
    tar: Exiting with failure status due to previous errors

    This means the .tar file is incomplete.

    opened by banrenmasanxing 6
  • my own datasets are all full human body images

    Hi, thanks for your excellent work. I am preparing my own dataset, which consists of thousands of high-resolution images (4000*4000 on average). They are all full human body images. When I process these images, I run into a question: when I crop the trimap (generated from alpha), the crop often lands on regions that do not include hair, such as a foot or a leg. Is it OK to input these images into the network?

    opened by lfxx 5
  • run demo.py question!

    File "demo.py", line 84, in <module>
      new_bgs = random.sample(new_bgs, 10)
    File "C:\Users\15432\AppData\Local\conda\conda\envs\python34\lib\random.py", line 324, in sample
      raise ValueError("Sample larger than population")
    ValueError: Sample larger than population

    opened by kxcg99 5
  • Invalid BEST_checkpoint.tar ?

    Hi, thank you for the code. I tried to download the pretrained model and extract it, but it doesn't work.

    tar xvf BEST_checkpoint.tar BEST_checkpoint

    results in

    tar: This does not look like a tar archive
    tar: Skipping to next header
    tar: BEST_checkpoint: Not found in archive
    tar: Exiting with failure status due to previous errors

    Am I doing anything wrong, or is the provided tar not valid? Kind regards

    opened by flocreate 4
  • How can i get the Trimaps of my pictures?

    I now have a model and want to use it, but I can't because I don't have trimaps for my pictures. Is there a script to build the trimaps? How can I get the trimaps of my pictures?

    opened by huangjunxiong11 3
  • can not unpack the 'BEST_checkpoint.tar'

    I downloaded the file "BEST_checkpoint.tar" successfully, but I can't unpack it: trying to unpack 'BEST_checkpoint.tar' produces an error. Is it my fault, or is the file broken?

    opened by huangjunxiong11 3
  • Demo error

    /Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py:435: SourceChangeWarning: source code of class 'torch.nn.parallel.data_parallel.DataParallel' has changed. you can retrieve the original source code by accessing the object's source attribute or set torch.nn.Module.dump_patches = True and use the patch tool to revert the changes.
      warnings.warn(msg, SourceChangeWarning)
    /Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py:435: SourceChangeWarning: source code of class 'torch.nn.modules.conv.Conv2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set torch.nn.Module.dump_patches = True and use the patch tool to revert the changes.
      warnings.warn(msg, SourceChangeWarning)
    Traceback (most recent call last):
      File "demo.py", line 69, in <module>
        checkpoint = torch.load(checkpoint)
      File "/Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 368, in load
        return _load(f, map_location, pickle_module)
      File "/Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 542, in _load
        result = unpickler.load()
      File "/Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 505, in persistent_load
        data_type(size), location)
      File "/Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 114, in default_restore_location
        result = fn(storage, location)
      File "/Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 95, in _cuda_deserialize
        device = validate_cuda_device(location)
      File "/Users/7plus/opt/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 79, in validate_cuda_device
        raise RuntimeError('Attempting to deserialize object on a CUDA '
    RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location='cpu' to map your storages to the CPU.

    opened by Mlt123 3
  • Deep-Image-Matting-v2 implementation on Android

    Hi, thanks for your work! The output looks awesome. I want to integrate your demo into an Android project. Is it possible to integrate the model into an Android project? If so, how can I do it? Can you please give some suggestions? Thanks in advance.

    opened by charlizesmith 3
  • unable to start training using pretrained weights

    Whenever pre-trained weights are used to train the model on my own dataset, the following error occurs.

    python3 train.py --batch-size 4 --checkpoint checkpoint/BEST_checkpoint.tar

    /usr/local/lib/python3.5/dist-packages/torch/serialization.py:454: SourceChangeWarning: source code of class 'torch.nn.parallel.data_parallel.DataParallel' has changed. you can retrieve the original source code by accessing the object's source attribute or set torch.nn.Module.dump_patches = True and use the patch tool to revert the changes.
      warnings.warn(msg, SourceChangeWarning)
    /usr/local/lib/python3.5/dist-packages/torch/serialization.py:454: SourceChangeWarning: source code of class 'torch.nn.modules.conv.Conv2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set torch.nn.Module.dump_patches = True and use the patch tool to revert the changes.
      warnings.warn(msg, SourceChangeWarning)
    /usr/local/lib/python3.5/dist-packages/torch/serialization.py:454: SourceChangeWarning: source code of class 'torch.nn.modules.batchnorm.BatchNorm2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set torch.nn.Module.dump_patches = True and use the patch tool to revert the changes.
      warnings.warn(msg, SourceChangeWarning)
    /usr/local/lib/python3.5/dist-packages/torch/serialization.py:454: SourceChangeWarning: source code of class 'torch.nn.modules.activation.ReLU' has changed. you can retrieve the original source code by accessing the object's source attribute or set torch.nn.Module.dump_patches = True and use the patch tool to revert the changes.
      warnings.warn(msg, SourceChangeWarning)
    Traceback (most recent call last):
      File "train.py", line 180, in <module>
        main()
      File "train.py", line 176, in main
        train_net(args)
      File "train.py", line 71, in train_net
        logger=logger)
      File "train.py", line 112, in train
        alpha_out = model(img) # [N, 3, 320, 320]
      File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/module.py", line 493, in __call__
        result = self.forward(*input, **kwargs)
      File "/usr/local/lib/python3.5/dist-packages/torch/nn/parallel/data_parallel.py", line 143, in forward
        if t.device != self.src_device_obj:
      File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/module.py", line 539, in __getattr__
        type(self).__name__, name))
    AttributeError: 'DataParallel' object has no attribute 'src_device_obj'

    opened by dev-srikanth 3
  • v2 doesn't perform as well as v1?

    Hi, thanks for your pretrained model! I tested both your v1 and v2 pretrained models. v2 is much faster than v1, but I found it doesn't perform as well as v1.

    The image: WechatIMG226
    The original trimap: test7_tri
    The v1 output: WechatIMG225
    The v2 output: test7_result

    Do you know what the problem is?

    Thanks,

    opened by MarSaKi 3
  • Questions about the PyTorch version and a training issue related to the batch size

    Hi,

    Thank you for sharing your PyTorch reimplementation. Could you share the PyTorch version you used for development?

    I am using PyTorch 1.0.1, CUDA 9, and two RTX 2080 Ti cards to run 'train.py', since I see you use the DataParallel module to support multi-GPU training. However, I encountered an error; the traceback is:

    Traceback (most recent call last):
      File "train.py", line 171, in <module>
        main()
      File "train.py", line 167, in main
        train_net(args)
      File "train.py", line 64, in train_net
        logger=logger)
      File "train.py", line 103, in train
        alpha_out = model(img) # [N, 3, 320, 320]
      File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 143, in forward
        outputs = self.parallel_apply(replicas, inputs, kwargs)
      File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 153, in parallel_apply
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
      File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in parallel_apply
        raise output
      File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 59, in _worker
        output = module(*input, **kwargs)
      File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/mingfu/Deep-Image-Matting-v2/models.py", line 127, in forward
        up4 = self.up4(up5, indices_4, unpool_shape4)
      File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/mingfu/Deep-Image-Matting-v2/models.py", line 87, in forward
        outputs = self.conv(outputs)
      File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/mingfu/Deep-Image-Matting-v2/models.py", line 43, in forward
        outputs = self.cbr_unit(inputs)
      File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/container.py", line 92, in forward
        input = module(input)
      File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/mingfu/anaconda3/envs/tensorflow_gpu/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 320, in forward
        self.padding, self.dilation, self.groups)
    RuntimeError: cuDNN error: CUDNN_STATUS_EXECUTION_FAILED

    I have tested data parallelism using the example here, and it works well.

    opened by wuyujack 3
Owner
Yang Liu
Algorithm engineer