PyTorch implementation of Soft-DTW: a Differentiable Loss Function for Time-Series in CUDA

Overview

Soft DTW Loss Function for PyTorch in CUDA

This is a PyTorch implementation of Soft-DTW: a Differentiable Loss Function for Time-Series that supports batched computation, is CUDA-friendly, and can be used as a final loss. I can confirm that you can train a (sequential) model with this as the final loss. The following image shows training logs of a TTS model using the Soft-DTW loss function.

There are some previous implementations:

  1. mblondel's soft-dtw
  2. lyprince's sdtw_pytorch
  3. Maghoumi's pytorch-softdtw-cuda

However, they either do not support CUDA-friendly batch computation or do not consider the Jacobian w.r.t. the input matrix, which is necessary for use as a final loss in recent deep learning frameworks. The current implementation satisfies all of these conditions.
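For reference, the quantities involved can be summarized as follows. This is the standard Soft-DTW formulation from the paper, not a description of this repository's kernels. It assumes a squared Euclidean cost, with D the pairwise cost matrix between x (length n) and y (length m) and γ > 0 the smoothing parameter:

```latex
\begin{aligned}
\min{}^{\gamma}(a_1,\dots,a_k) &= -\gamma \log \sum_{i=1}^{k} e^{-a_i/\gamma} \\
D_{i,j} &= \lVert x_i - y_j \rVert^2 \\
R_{i,j} &= D_{i,j} + \min{}^{\gamma}\bigl(R_{i-1,j-1},\, R_{i-1,j},\, R_{i,j-1}\bigr), \qquad R_{0,0} = 0 \\
\mathrm{sdtw}_{\gamma}(x, y) &= R_{n,m} \\
E_{i,j} &= \frac{\partial R_{n,m}}{\partial D_{i,j}}, \qquad
\frac{\partial\, \mathrm{sdtw}_{\gamma}(x, y)}{\partial x_i} = \sum_{j=1}^{m} E_{i,j}\, 2\,(x_i - y_j)
\end{aligned}
```

The last line is the chain rule through D: the gradient with respect to the cost matrix (E) has to be multiplied by the Jacobian of D with respect to the input sequence before it can reach the model parameters.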

Usage

Same as Maghoumi's pytorch-softdtw-cuda:

import torch
from sdtw_cuda_loss import SoftDTW

# Create the sequences on the GPU (use_cuda=True expects CUDA tensors)
batch_size, len_x, len_y, dims = 8, 15, 12, 5
x = torch.rand((batch_size, len_x, dims), device="cuda", requires_grad=True)
y = torch.rand((batch_size, len_y, dims), device="cuda")

# Create the "criterion" object
sdtw = SoftDTW(use_cuda=True, gamma=0.1)

# Compute the loss value (one value per batch element)
loss = sdtw(x, y)  # Just like any torch.nn.xyzLoss()

# Aggregate and call backward()
loss.mean().backward()

Unlike the previous work, however, backward() also computes the gradient w.r.t. the input target sequence x.

Note

The current implementation supports only use_cuda=True. You can, however, implement a CPU version along the lines of Maghoumi's pytorch-softdtw-cuda.
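As a starting point for a CPU version, the forward recursion can be sketched in plain Python. This is an illustrative reference only, not code from this repository or from pytorch-softdtw-cuda. The function names softmin and soft_dtw are hypothetical, a squared Euclidean cost is assumed, and the sketch handles a single pair of sequences without batching or gradients.

```python
import math


def softmin(values, gamma):
    """Smoothed minimum: -gamma * log(sum(exp(-v / gamma))), computed stably."""
    m = min(values)
    if math.isinf(m):
        # All candidates are unreachable; avoid computing inf - inf.
        return m
    # Subtract the minimum before exponentiating to avoid overflow.
    total = sum(math.exp(-(v - m) / gamma) for v in values)
    return m - gamma * math.log(total)


def soft_dtw(x, y, gamma=0.1):
    """Soft-DTW value between two sequences of vectors (lists of lists)."""
    n, m = len(x), len(y)
    # Pairwise squared Euclidean cost between every x[i] and y[j] (assumed cost).
    dist = [[sum((a - b) ** 2 for a, b in zip(xi, yj)) for yj in y] for xi in x]

    # Accumulated cost table with an extra border row and column.
    inf = float("inf")
    r = [[inf] * (m + 1) for _ in range(n + 1)]
    r[0][0] = 0.0

    for i in range(1, n + 1):
        for j in range(1, m + 1):
            prev = (r[i - 1][j - 1], r[i - 1][j], r[i][j - 1])
            r[i][j] = dist[i - 1][j - 1] + softmin(prev, gamma)
    return r[n][m]


if __name__ == "__main__":
    x = [[0.0], [1.0], [2.0]]
    y = [[0.0], [1.0], [1.0], [2.0]]
    print(soft_dtw(x, y, gamma=0.01))
```

Because the smoothed minimum never exceeds the true minimum, the value is at most the classic DTW cost and approaches it as gamma goes to zero. A batched, differentiable version would additionally need the backward recursion for the gradient with respect to the cost matrix.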

Citation

@misc{lee2021soft_dtw_loss,
  author = {Lee, Keon},
  title = {Soft-DTW-Loss},
  year = {2021},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/keonlee9420/Soft-DTW-Loss}}
}
Comments
  • Does this support multi-GPU training?

    Thanks for sharing this implementation of Soft-DTW. I can use it in a single-GPU environment, but not in a multi-GPU one. Does it currently not support multi-GPU training?

    opened by mayfool 2
  • How to use dtw-loss to fit a curve?

    Hello, I tried to fit a curve (a set of discrete points) using Soft-DTW-Loss as the loss function, but the loss does not converge to the expected result. Is there something wrong with the way I am using it? The code is as follows:

    import torch
    import torch.nn as nn
    from sdtw_cuda_loss import SoftDTW

    if __name__ == "__main__":
        batch_size = 1
        len_x = 15
        len_predict = 10
        dims = 1

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # Ground-truth curve: y = x^2 sampled at len_x points
        x = torch.unsqueeze(torch.linspace(1, 4, steps=len_x, requires_grad=True), dim=0)
        y = x ** 2
        y = y.view(1, len_x, 1)
        x = x.view(1, len_x, 1)

        # (batch, length, dims) ----> (1, 15, 2)
        truth_points = torch.cat((y, x), dim=2).cuda()

        # Fixed network input of shape (1, 20)
        inputs = torch.unsqueeze(torch.linspace(1, 4, steps=len_predict * 2, requires_grad=True), dim=0).cuda()

        class testNN(torch.nn.Module):
            def __init__(self):
                super(testNN, self).__init__()
                self.layer = nn.Sequential(
                    nn.Linear(20, 50),
                    nn.ReLU(),
                    nn.Linear(50, 200),
                    nn.ReLU(),
                    nn.Linear(200, 50),
                    nn.ReLU(),
                    nn.Linear(50, 20),
                    nn.ReLU(),
                )

            def forward(self, x):
                x = self.layer(x)
                return x

        test = testNN()
        test = test.to(device)

        loss_function = SoftDTW(use_cuda=True, gamma=0.01, normalize=False)
        optimizer = torch.optim.Adam(test.parameters(), lr=0.01)

        for epoch in range(1000):
            predict = test(inputs)
            # (1, 20) reshaped to (1, 10, 2)
            predict = predict.reshape(1, len_predict, 2)
            loss = loss_function(predict, truth_points)
            optimizer.zero_grad()
            loss.mean().backward(retain_graph=True)
            optimizer.step()

            if epoch % 10 == 0:
                print("epoch : %d | loss : %f" % (epoch, loss.item()))
                plt_predict = predict.cpu().detach().numpy()
                # print(plt_predict)
                plt_predict = plt_predict.reshape(1, len_predict, 2)
                print(plt_predict[0, :, 0])
                print(plt_predict[0, :, 1])

    opened by visionlyx 0
Releases(v1.0.0)
Owner
Keon Lee
Expressive Speech Synthesis | Conversational AI | Open-domain Dialog | NLP | Generative Models | Empathic Computing | HCI