Implementation of ETSformer, state of the art time-series Transformer, in Pytorch

Overview

ETSformer - Pytorch

Implementation of ETSformer, state of the art time-series Transformer, in Pytorch

Install

$ pip install etsformer-pytorch

Usage

import torch
from etsformer_pytorch import ETSFormer

model = ETSFormer(
    time_features = 4,
    model_dim = 512,                # in paper they use 512
    embed_kernel_size = 3,          # kernel size for 1d conv for input embedding
    layers = 2,                     # number of encoder and corresponding decoder layers
    heads = 8,                      # number of exponential smoothing attention heads
    K = 4,                          # num frequencies with highest amplitude to keep (attend to)
    dropout = 0.2                   # dropout (in paper they did 0.2)
)

timeseries = torch.randn(1, 1024, 4)

pred = model(timeseries, num_steps_forecast = 32) # (1, 32, 4) - (batch, num steps forecast, num time features)

For using ETSFormer for classification, using cross attention pooling on all latents and level output

import torch
from etsformer_pytorch import ETSFormer, ClassificationWrapper

etsformer = ETSFormer(
    time_features = 1,
    model_dim = 512,
    embed_kernel_size = 3,
    layers = 2,
    heads = 8,
    K = 4,
    dropout = 0.2
)

adapter = ClassificationWrapper(
    etsformer = etsformer,
    dim_head = 32,
    heads = 16,
    dropout = 0.2,
    level_kernel_size = 5,
    num_classes = 10
)

timeseries = torch.randn(1, 1024)

logits = adapter(timeseries) # (1, 10)

Citation

@misc{woo2022etsformer,
    title   = {ETSformer: Exponential Smoothing Transformers for Time-series Forecasting}, 
    author  = {Gerald Woo and Chenghao Liu and Doyen Sahoo and Akshat Kumar and Steven Hoi},
    year    = {2022},
    eprint  = {2202.01381},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
Comments
  • What are your thoughts on using latents for additional classification task

    What are your thoughts on using latents for additional classification task

    Hi! I was wondering if you have thought about aggregating seasonal and growth latents for additional tasks (for example classification)? What are the possible ways to bring latents into single feature vector in your opinion? The easiest one would be just get the mean along layers and time dimensions but that seams to be too naive. Another idea I had it to use Cross Attention mechanic with single time query key to aggregate latents:

    all_latents = torch.cat([latent_growths, latent_seasonals], dim=-1)
    all_latents = rearrange(all_latents, 'b n l d -> (b l) n d')
    # q = nn.Parameter(torch.randn(all_latents_dim))
    q = repeat(q, 'd -> b 1 d', b = all_latents.shape[0])
    agg_latent = cross_attention(query=q, context=all_latents)
    agg_latent = rearrange(all_latents, '(b l) n d -> b (l n) d')
    agg_latent = agg_latent.mean(dim=1) # may be we should have done it before cross attention?
    

    Would be great to hear your thoughts

    opened by inspirit 15
  • Pre LayerNorm might be required for k,v?

    Pre LayerNorm might be required for k,v?

    https://github.com/lucidrains/ETSformer-pytorch/blob/2561053007e919409b3255eb1d0852c68799d24f/etsformer_pytorch/etsformer_pytorch.py#L440

    In my early tests I see some instability in training results, I was wondering if it might be good idea to LayerNorm latents before constructing key and values?

    opened by inspirit 5
  • growth_term calculation error

    growth_term calculation error

    https://github.com/lucidrains/ETSformer-pytorch/blob/e1d8514b44d113ead523aa6307986833e68eecc5/etsformer_pytorch/etsformer_pytorch.py#L233-L235

    It looks like you are not using growth and growth_smoothing_weightsto calculate growth_term

    opened by inspirit 4
  • Backward gradient error

    Backward gradient error

    Hello,

    i was trying to run the provided class and see following error: Function ScatterBackward0 returned an invalid gradient at index 1 - got [64, 4, 128] but expected shape compatible with [64, 33, 128]

    model = ETSFormer(
                time_features = 9,
                model_dim = 128,
                embed_kernel_size = 3,
                layers = 2,
                heads = 4,
                K = 4,
                dropout = 0.2
            )
    

    input = torch.rand(64, 64, 9) x = model(input, num_steps_forecast = 16)

    opened by inspirit 3
  • Does ETS-Former allow adding features

    Does ETS-Former allow adding features

    @lucidrains Thanks for making the code of the model available!

    In your paper, you state that the model infers seasonal patterns itself, so that there is no need to add time features like week, month, etc.

    Still, to increase the applicability of your approach, does the current implementation allow to add any (time-invariant and time-varying) features, e.g., categorical or numeric?

    opened by StatMixedML 2
  • wrong order of arguments

    wrong order of arguments

    https://github.com/lucidrains/ETSformer-pytorch/blob/2e0d465576c15fc8d84c4673f93fdd71d45b799c/etsformer_pytorch/etsformer_pytorch.py#L327

    you pass latents on wrong order to Level module: according to forward method first should be growth and then seasonal

    opened by inspirit 1
  • Clarification regarding data pre-processing

    Clarification regarding data pre-processing

    Hello,

    I was trying to run the ETSformer for ETT dataset. The paper mentions that the dataset is split as 60/20/20 for train, validation and test. Could you give some insight as to how the dataset split is happening in the code.

    Thank you.

    opened by vageeshmaiya 2
Owner
Phil Wang
Working with Attention. It's all we need
Phil Wang
Implementation of EMNLP 2017 Paper "Natural Language Does Not Emerge 'Naturally' in Multi-Agent Dialog" using PyTorch and ParlAI

Language Emergence in Multi Agent Dialog Code for the Paper Natural Language Does Not Emerge 'Naturally' in Multi-Agent Dialog Satwik Kottur, José M.

Karan Desai 105 Nov 25, 2022
ICLR 2021: Pre-Training for Context Representation in Conversational Semantic Parsing

SCoRe: Pre-Training for Context Representation in Conversational Semantic Parsing This repository contains code for the ICLR 2021 paper "SCoRE: Pre-Tr

Microsoft 28 Oct 02, 2022
MetaDrive: Composing Diverse Scenarios for Generalizable Reinforcement Learning

MetaDrive: Composing Diverse Driving Scenarios for Generalizable RL [ Documentation | Demo Video ] MetaDrive is a driving simulator with the following

DeciForce: Crossroads of Machine Perception and Autonomy 276 Jan 04, 2023
Code for our ACL 2021 paper - ConSERT: A Contrastive Framework for Self-Supervised Sentence Representation Transfer

ConSERT Code for our ACL 2021 paper - ConSERT: A Contrastive Framework for Self-Supervised Sentence Representation Transfer Requirements torch==1.6.0

Yan Yuanmeng 478 Dec 25, 2022
Iterative Training: Finding Binary Weight Deep Neural Networks with Layer Binarization

Iterative Training: Finding Binary Weight Deep Neural Networks with Layer Binarization This repository contains the source code for the paper (link wi

Rakuten Group, Inc. 0 Nov 19, 2021
a spacial-temporal pattern detection system for home automation

Argos a spacial-temporal pattern detection system for home automation. Based on OpenCV and Tensorflow, can run on raspberry pi and notify HomeAssistan

Angad Singh 133 Jan 05, 2023
A minimal implementation of face-detection models using flask, gunicorn, nginx, docker, and docker-compose

Face-Detection-flask-gunicorn-nginx-docker This is a simple implementation of dockerized face-detection restful-API implemented with flask, Nginx, and

Pooya-Mohammadi 30 Dec 17, 2022
3D detection and tracking viewer (visualization) for kitti & waymo dataset

3D detection and tracking viewer (visualization) for kitti & waymo dataset

222 Jan 08, 2023
An implementation of the efficient attention module.

Efficient Attention An implementation of the efficient attention module. Description Efficient attention is an attention mechanism that substantially

Shen Zhuoran 194 Dec 15, 2022
Bayesian optimization in PyTorch

BoTorch is a library for Bayesian Optimization built on PyTorch. BoTorch is currently in beta and under active development! Why BoTorch ? BoTorch Prov

2.5k Dec 31, 2022
Time-Optimal Planning for Quadrotor Waypoint Flight

Time-Optimal Planning for Quadrotor Waypoint Flight This is an example implementation of the paper "Time-Optimal Planning for Quadrotor Waypoint Fligh

Robotics and Perception Group 38 Dec 02, 2022
learned_optimization: Training and evaluating learned optimizers in JAX

learned_optimization: Training and evaluating learned optimizers in JAX learned_optimization is a research codebase for training learned optimizers. I

Google 533 Dec 30, 2022
MBPO (paper: When to trust your model: Model-based policy optimization) in offline RL settings

offline-MBPO This repository contains the code of a version of model-based RL algorithm MBPO, which is modified to perform in offline RL settings Pape

LxzGordon 1 Oct 24, 2021
PyTorch implementation of Wide Residual Networks with 1-bit weights by McDonnell (ICLR 2018)

1-bit Wide ResNet PyTorch implementation of training 1-bit Wide ResNets from this paper: Training wide residual networks for deployment using a single

Sergey Zagoruyko 122 Dec 07, 2022
Using deep learning to predict gene structures of the coding genes in DNA sequences of Arabidopsis thaliana

DeepGeneAnnotator: A tool to annotate the gene in the genome The master thesis of the "Using deep learning to predict gene structures of the coding ge

Ching-Tien Wang 3 Sep 09, 2022
COD-Rank-Localize-and-Segment (CVPR2021)

COD-Rank-Localize-and-Segment (CVPR2021) Simultaneously Localize, Segment and Rank the Camouflaged Objects Full camouflage fixation training dataset i

JingZhang 52 Dec 20, 2022
Stochastic Tensor Optimization for Robot Motion - A GPU Robot Motion Toolkit

STORM Stochastic Tensor Optimization for Robot Motion - A GPU Robot Motion Toolkit [Install Instructions] [Paper] [Website] This package contains code

NVIDIA Research Projects 101 Dec 12, 2022
Unofficial PyTorch reimplementation of the paper Swin Transformer V2: Scaling Up Capacity and Resolution

PyTorch reimplementation of the paper Swin Transformer V2: Scaling Up Capacity and Resolution [arXiv 2021].

Christoph Reich 122 Dec 12, 2022
Official code release for: EditGAN: High-Precision Semantic Image Editing

Official code release for: EditGAN: High-Precision Semantic Image Editing

565 Jan 05, 2023
Trained on Simulated Data, Tested in the Real World

Trained on Simulated Data, Tested in the Real World

livox 43 Nov 18, 2022