Repository for "Improving evidential deep learning via multi-task learning," published in AAAI2022

Overview

Improving evidential deep learning via multi task learning

It is a repository of AAAI2022 paper, “Improving evidential deep learning via multi-task learning”, by Dongpin Oh and Bonggun Shin.

This repository contains the code to reproduce the Multi-task evidential neural network (MT-ENet), which uses the Lipschitz MSE loss function as the additional loss function of the evidential regression network (ENet). The Lipschitz MSE loss function can improve the accuracy of the ENet while preserving its uncertainty estimation capability, by avoiding gradient conflict with the NLL loss function—the original loss function of the ENet.

drawing

Setup

Please refer to "requirements.txt" for requring packages of this repo.

pip install -r requirements.txt

Training the ENet with the Lipschitz-MSE loss: example

from mtevi.mtevi import EvidentialMarginalLikelihood, EvidenceRegularizer, modified_mse
...
net = EvidentialNetwork() ## Evidential regression network
nll_loss = EvidentialMarginalLikelihood() ## original loss, NLL loss
reg = EvidenceRegularizer() ## evidential regularizer
mmse_loss = modified_mse ## lipschitz MSE loss
...
for inputs, labels in dataloader:
	gamma, nu, alpha, beta = net(inputs)
	loss = nll_loss(gamma, nu, alpha, beta, labels)
	loss += reg(gamma, nu, alpha, beta, labels)
	loss += mmse_loss(gamma, nu, alpha, beta, labels)
	loss.backward()	

Quick start

  • Synthetic data experiment.
python synthetic_exp.py
  • UCI regression benchmark experiments.
python uci_exp_norm -p energy
  • Drug target affinity (DTA) regression task on KIBA and Davis datasets.
python train_evinet.py -o test --type davis -f 0 --evi # ENet
python train_evinet.py -o test --type davis -f 0  # MT-ENet
  • Gradient conflict experiment on the DTA benchmarks
python check_conflict.py --type davis -f 0 # Conflict between the Lipschitz MSE (proposed) and NLL loss. 
python check_conflict.py --type davis -f 0 --abl # Conflict between the simple MSE loss and NLL loss.

Characteristic of the Lipschitz MSE loss

drawing

  • The Lipschitz MSE loss function can support training the ENet to more accurately predicts target values.
  • It regularizes its gradient to prevent gradient conflict with the NLL loss--the original loss function--if the NLL loss increases predictive uncertainty of the ENet.
  • Please check our paper for details.
Owner
deargen
deargen
Aydin is a user-friendly, feature-rich, and fast image denoising tool

Aydin is a user-friendly, feature-rich, and fast image denoising tool that provides a number of self-supervised, auto-tuned, and unsupervised image denoising algorithms.

Royer Lab 99 Dec 14, 2022
Tutorial in Python targeted at Epidemiologists. Will discuss the basics of analysis in Python 3

Python-for-Epidemiologists This repository is an introduction to epidemiology analyses in Python. Additionally, the tutorials for my library zEpid are

Paul Zivich 120 Nov 17, 2022
PyTorch code for BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation

BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation

Salesforce 1.3k Dec 31, 2022
Official repo of the paper "Surface Form Competition: Why the Highest Probability Answer Isn't Always Right"

Surface Form Competition This is the official repo of the paper "Surface Form Competition: Why the Highest Probability Answer Isn't Always Right" We p

Peter West 46 Dec 23, 2022
PyTorch implementation of SimCLR: A Simple Framework for Contrastive Learning of Visual Representations

PyTorch implementation of SimCLR: A Simple Framework for Contrastive Learning of Visual Representations

Thalles Silva 1.7k Dec 28, 2022
3D ResNets for Action Recognition (CVPR 2018)

3D ResNets for Action Recognition Update (2020/4/13) We published a paper on arXiv. Hirokatsu Kataoka, Tenga Wakamiya, Kensho Hara, and Yutaka Satoh,

Kensho Hara 3.5k Jan 06, 2023
Tensorflow implementation for "Improved Transformer for High-Resolution GANs" (NeurIPS 2021).

HiT-GAN Official TensorFlow Implementation HiT-GAN presents a Transformer-based generator that is trained based on Generative Adversarial Networks (GA

Google Research 78 Oct 31, 2022
PyTorch implementations of the beta divergence loss.

Beta Divergence Loss - PyTorch Implementation This repository contains code for a PyTorch implementation of the beta divergence loss. Dependencies Thi

Billy Carson 7 Nov 09, 2022
【CVPR 2021, Variational Inference Framework, PyTorch】 From Rain Generation to Rain Removal

From Rain Generation to Rain Removal (CVPR2021) Hong Wang, Zongsheng Yue, Qi Xie, Qian Zhao, Yefeng Zheng, and Deyu Meng [PDF&&Supplementary Material]

Hong Wang 48 Nov 23, 2022
PyTorch implementation of the NIPS-17 paper "Poincaré Embeddings for Learning Hierarchical Representations"

Poincaré Embeddings for Learning Hierarchical Representations PyTorch implementation of Poincaré Embeddings for Learning Hierarchical Representations

Facebook Research 1.6k Dec 25, 2022
Quantum-enhanced transformer neural network

Example of a Quantum-enhanced transformer neural network Get the code: git clone https://github.com/rdisipio/qtransformer.git cd qtransformer Create

Riccardo Di Sipio 61 Nov 08, 2022
PyTorch CZSL framework containing GQA, the open-world setting, and the CGE and CompCos methods.

Compositional Zero-Shot Learning This is the official PyTorch code of the CVPR 2021 works Learning Graph Embeddings for Compositional Zero-shot Learni

EML Tübingen 70 Dec 27, 2022
Efficient training of deep recommenders on cloud.

HybridBackend Introduction HybridBackend is a training framework for deep recommenders which bridges the gap between evolving cloud infrastructure and

Alibaba 111 Dec 23, 2022
Code for paper Novel View Synthesis via Depth-guided Skip Connections

Novel View Synthesis via Depth-guided Skip Connections Code for paper Novel View Synthesis via Depth-guided Skip Connections @InProceedings{Hou_2021_W

8 Mar 14, 2022
A simple interface for editing natural photos with generative neural networks.

Neural Photo Editor A simple interface for editing natural photos with generative neural networks. This repository contains code for the paper "Neural

Andy Brock 2.1k Dec 29, 2022
Python implementation of a live deep learning based age/gender/expression recognizer

TUT live age estimator Python implementation of a live deep learning based age/gender/smile/celebrity twin recognizer. All components use convolutiona

Heikki Huttunen 80 Nov 21, 2022
QICK: Quantum Instrumentation Control Kit

QICK: Quantum Instrumentation Control Kit The QICK is a kit of firmware and software to use the Xilinx RFSoC to control quantum systems. It consists o

81 Dec 15, 2022
PushForKiCad - AISLER Push for KiCad EDA

AISLER Push for KiCad Push your layout to AISLER with just one click for instant

AISLER 31 Dec 29, 2022
Improving XGBoost survival analysis with embeddings and debiased estimators

xgbse: XGBoost Survival Embeddings "There are two cultures in the use of statistical modeling to reach conclusions from data

Loft 242 Dec 30, 2022
Quasi-Dense Similarity Learning for Multiple Object Tracking, CVPR 2021 (Oral)

Quasi-Dense Tracking This is the offical implementation of paper Quasi-Dense Similarity Learning for Multiple Object Tracking. We present a trailer th

ETH VIS Research Group 327 Dec 27, 2022