Implementation of ETSformer, state of the art time-series Transformer, in Pytorch


ETSformer - Pytorch


Install

$ pip install etsformer-pytorch

Usage

import torch
from etsformer_pytorch import ETSFormer

model = ETSFormer(
    time_features = 4,
    model_dim = 512,                # the paper uses 512
    embed_kernel_size = 3,          # kernel size for 1d conv for input embedding
    layers = 2,                     # number of encoder and corresponding decoder layers
    heads = 8,                      # number of exponential smoothing attention heads
    K = 4,                          # number of highest-amplitude frequencies to keep (attend to)
    dropout = 0.2                   # dropout rate (the paper uses 0.2)
)

timeseries = torch.randn(1, 1024, 4)

pred = model(timeseries, num_steps_forecast = 32) # (1, 32, 4) - (batch, num steps forecast, num time features)
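The `heads` argument above refers to exponential smoothing attention. As a rough illustration of the idea, and not the repository's implementation, exponential smoothing can be written as a fixed causal attention matrix whose weights decay geometrically with lag. The minimal sketch below assumes a single scalar smoothing factor `alpha` and an initial state `v0`; the function names are hypothetical, and the actual module may learn `alpha` per head and differ in other details.

```python
import torch


def exp_smoothing_weights(n: int, alpha: torch.Tensor):
    """Return (weights, init_weights) such that weights @ x + init_weights[:, None] * v0
    equals the recurrence s_t = alpha * x_t + (1 - alpha) * s_{t-1}, with s_{-1} = v0."""
    t = torch.arange(n)
    lag = (t[:, None] - t[None, :]).clamp(min=0)   # (n, n): how far token j lies behind t
    causal = (t[:, None] >= t[None, :]).float()     # zero out future positions
    weights = alpha * (1 - alpha) ** lag * causal   # alpha * (1 - alpha)^(t - j) for j <= t
    init_weights = (1 - alpha) ** (t + 1)           # weight left on the initial state v0
    return weights, init_weights


def exp_smoothing_attention(x, alpha, v0):
    """x: (batch, n, dim); alpha: scalar tensor in (0, 1); v0: (dim,) initial state."""
    weights, init_weights = exp_smoothing_weights(x.shape[1], alpha)
    return weights @ x + init_weights[:, None] * v0


def exp_smoothing_recurrent(x, alpha, v0):
    """Reference loop implementation of the same smoother, used only for comparison."""
    state = v0.expand(x.shape[0], -1)
    outputs = []
    for step in range(x.shape[1]):
        state = alpha * x[:, step] + (1 - alpha) * state
        outputs.append(state)
    return torch.stack(outputs, dim=1)


torch.manual_seed(0)
x = torch.randn(2, 16, 4)
alpha = torch.tensor(0.3)
v0 = torch.randn(4)
fast = exp_smoothing_attention(x, alpha, v0)   # one matrix multiply
slow = exp_smoothing_recurrent(x, alpha, v0)   # sequential loop
```

Writing the smoother as a matrix lets every time step be computed in one matrix multiply instead of a sequential loop; the loop version is there only to check that the two agree.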
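The `K` argument keeps the frequencies with the highest amplitude. A minimal sketch of that selection step follows, assuming a real FFT along the time axis and (as an assumption) never selecting the DC (mean) term. The function name is hypothetical, and this version only reconstructs the input window; it does not extrapolate the selected frequencies into the forecast horizon, as a forecasting model would.

```python
import math

import torch


def top_k_frequency_filter(x: torch.Tensor, k: int) -> torch.Tensor:
    """Keep only the k largest-amplitude frequencies along time.

    x: (batch, n, features); returns a real tensor of the same shape.
    """
    n = x.shape[1]
    spectrum = torch.fft.rfft(x, dim=1)             # (batch, n // 2 + 1, features), complex
    amplitude = spectrum.abs()
    amplitude[:, 0] = 0                             # never select the DC (mean) term
    top_idx = amplitude.topk(k, dim=1).indices      # (batch, k, features)
    mask = torch.zeros_like(amplitude).scatter_(1, top_idx, 1.0)
    return torch.fft.irfft(spectrum * mask, n=n, dim=1)


t = torch.arange(64, dtype=torch.float32)
seasonal = 3 * torch.sin(2 * math.pi * 4 * t / 64) + 1.5 * torch.cos(2 * math.pi * 9 * t / 64)
noise = 0.1 * torch.sin(2 * math.pi * 20 * t / 64)
x = (seasonal + noise).reshape(1, 64, 1)
filtered = top_k_frequency_filter(x, k=2)   # drops the weak 20-cycle component
```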

To use ETSFormer for classification, wrap it in ClassificationWrapper, which applies cross attention pooling over all latents and the level output:

import torch
from etsformer_pytorch import ETSFormer, ClassificationWrapper

etsformer = ETSFormer(
    time_features = 1,
    model_dim = 512,
    embed_kernel_size = 3,
    layers = 2,
    heads = 8,
    K = 4,
    dropout = 0.2
)

adapter = ClassificationWrapper(
    etsformer = etsformer,
    dim_head = 32,
    heads = 16,
    dropout = 0.2,
    level_kernel_size = 5,
    num_classes = 10
)

timeseries = torch.randn(1, 1024)

logits = adapter(timeseries) # (1, 10)
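The classification wrapper pools the model's latents with cross attention. The sketch below shows the general pattern of attention pooling with a single learned query, using PyTorch's `nn.MultiheadAttention`. It is an illustration under stated assumptions, not the internals of `ClassificationWrapper`: the class names are hypothetical, and it assumes the latents and level output have already been projected to one width and concatenated along the token axis.

```python
import torch
from torch import nn


class AttentionPool(nn.Module):
    """Pool a (batch, tokens, dim) sequence into (batch, dim) with one learned query."""

    def __init__(self, dim: int, heads: int = 4):
        super().__init__()
        self.query = nn.Parameter(torch.randn(1, 1, dim))
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        query = self.query.expand(tokens.shape[0], -1, -1)   # (batch, 1, dim)
        pooled, _ = self.attn(query, tokens, tokens)         # attend over every token
        return pooled.squeeze(1)                             # (batch, dim)


class PooledClassifier(nn.Module):
    """Attention pooling followed by a linear head producing class logits."""

    def __init__(self, dim: int, num_classes: int, heads: int = 4):
        super().__init__()
        self.pool = AttentionPool(dim, heads)
        self.to_logits = nn.Linear(dim, num_classes)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.to_logits(self.pool(tokens))


# Hypothetical input: latents and level output already projected to a shared
# width (64) and concatenated along the token axis.
tokens = torch.randn(2, 100, 64)
classifier = PooledClassifier(dim=64, num_classes=10)
logits = classifier(tokens)   # shape (2, 10)
```

Because the pooling step itself carries no positional information, the pooled vector does not depend on token order; any temporal structure has to come from the latents themselves.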

Citation

@misc{woo2022etsformer,
    title   = {ETSformer: Exponential Smoothing Transformers for Time-series Forecasting}, 
    author  = {Gerald Woo and Chenghao Liu and Doyen Sahoo and Akshat Kumar and Steven Hoi},
    year    = {2022},
    eprint  = {2202.01381},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
Comments
  • What are your thoughts on using latents for additional classification task

    Hi! I was wondering whether you have thought about aggregating the seasonal and growth latents for additional tasks (for example, classification). In your opinion, what are possible ways to bring the latents into a single feature vector? The easiest would be to take the mean along the layer and time dimensions, but that seems too naive. Another idea I had is to use cross attention with a single learned query to aggregate the latents:

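    import torch
    from torch import nn
    from einops import rearrange, repeat

    # assumed shapes: latent_growths, latent_seasonals are (batch, time, layers, dim)
    all_latents = torch.cat([latent_growths, latent_seasonals], dim=-1)
    b, n, l, d = all_latents.shape
    all_latents = rearrange(all_latents, 'b n l d -> (b l) n d')
    q = nn.Parameter(torch.randn(d))  # single learned query
    q = repeat(q, 'd -> b 1 d', b=all_latents.shape[0])
    cross_attention = nn.MultiheadAttention(d, num_heads=8, batch_first=True)  # assumes d divisible by 8
    agg_latent, _ = cross_attention(q, all_latents, all_latents)  # ((b l), 1, d)
    agg_latent = rearrange(agg_latent, '(b l) n d -> b (l n) d', l=l)
    agg_latent = agg_latent.mean(dim=1)  # maybe we should have done this before cross attention?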
    

    Would be great to hear your thoughts

    opened by inspirit 15
  • Pre LayerNorm might be required for k,v?

    https://github.com/lucidrains/ETSformer-pytorch/blob/2561053007e919409b3255eb1d0852c68799d24f/etsformer_pytorch/etsformer_pytorch.py#L440

    In my early tests I see some instability in the training results. I was wondering whether it might be a good idea to apply LayerNorm to the latents before constructing the keys and values?

    opened by inspirit 5
  • growth_term calculation error

    https://github.com/lucidrains/ETSformer-pytorch/blob/e1d8514b44d113ead523aa6307986833e68eecc5/etsformer_pytorch/etsformer_pytorch.py#L233-L235

    It looks like you are not using growth and growth_smoothing_weights to calculate growth_term.

    opened by inspirit 4
  • Backward gradient error

    Hello,

    I was trying to run the provided class and got the following error: Function ScatterBackward0 returned an invalid gradient at index 1 - got [64, 4, 128] but expected shape compatible with [64, 33, 128]

    model = ETSFormer(
                time_features = 9,
                model_dim = 128,
                embed_kernel_size = 3,
                layers = 2,
                heads = 4,
                K = 4,
                dropout = 0.2
            )
    

    input = torch.rand(64, 64, 9)
    x = model(input, num_steps_forecast = 16)

    opened by inspirit 3
  • Does ETS-Former allow adding features

    @lucidrains Thanks for making the code of the model available!

    In your paper, you state that the model infers seasonal patterns itself, so that there is no need to add time features like week, month, etc.

    Still, to increase the applicability of your approach, does the current implementation allow adding other (time-invariant and time-varying) features, e.g., categorical or numeric ones?

    opened by StatMixedML 2
  • wrong order of arguments

    https://github.com/lucidrains/ETSformer-pytorch/blob/2e0d465576c15fc8d84c4673f93fdd71d45b799c/etsformer_pytorch/etsformer_pytorch.py#L327

    You pass the latents to the Level module in the wrong order: according to its forward method, growth should come first and then seasonal.

    opened by inspirit 1
  • Clarification regarding data pre-processing

    Hello,

    I was trying to run ETSformer on the ETT dataset. The paper mentions that the dataset is split 60/20/20 into train, validation and test sets. Could you give some insight into how the dataset is split in the code?

    Thank you.

    opened by vageeshmaiya 2
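On the question above about applying LayerNorm to the latents before constructing the keys and values: a common pattern in Transformer code is pre-norm, where LayerNorm is applied to the inputs of attention rather than to its output. A minimal sketch with separate LayerNorms for the queries and for the latent context follows. The class name is hypothetical, this is not the repository's code, and whether it fixes the instability reported in that issue is not tested here.

```python
import torch
from torch import nn


class PreNormCrossAttention(nn.Module):
    """Cross attention whose queries and context are layer-normalized first."""

    def __init__(self, dim: int, heads: int = 4):
        super().__init__()
        self.norm_queries = nn.LayerNorm(dim)
        self.norm_context = nn.LayerNorm(dim)   # applied to the latents before k/v projection
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, x: torch.Tensor, latents: torch.Tensor) -> torch.Tensor:
        context = self.norm_context(latents)    # the same normalized tensor feeds keys and values
        out, _ = self.attn(self.norm_queries(x), context, context)
        return x + out                          # residual connection


torch.manual_seed(0)
block = PreNormCrossAttention(dim=32)
x = torch.randn(2, 8, 32)
latents = torch.randn(2, 50, 32)
out = block(x, latents)
# Rescaling and shifting the latents barely changes the output, because
# LayerNorm removes each token's mean and scale before the projections.
out_rescaled = block(x, latents * 100 + 5)
```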
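For the data pre-processing question above, a 60/20/20 split of a time series is normally chronological rather than random, so that no future values leak into training. The sketch below assumes a `(time, features)` tensor; both helper names are hypothetical, and standardizing every split with training-set statistics is a common convention shown here as an assumption, not a description of how the paper or this repository prepares the ETT data.

```python
import torch


def chronological_split(series: torch.Tensor, train_frac: float = 0.6, val_frac: float = 0.2):
    """Split a (time, features) series into contiguous train/val/test segments.

    No shuffling: every validation step comes after every training step, and
    every test step after every validation step.
    """
    n = series.shape[0]
    n_train = round(n * train_frac)
    n_val = round(n * val_frac)
    train = series[:n_train]
    val = series[n_train:n_train + n_val]
    test = series[n_train + n_val:]
    return train, val, test


def standardize_with_train_stats(train, val, test, eps: float = 1e-8):
    """Scale all splits with the mean/std computed on the training split only."""
    mean = train.mean(dim=0, keepdim=True)
    std = train.std(dim=0, keepdim=True) + eps
    return [(split - mean) / std for split in (train, val, test)]


series = torch.arange(1000, dtype=torch.float32).reshape(100, 10)   # 100 time steps, 10 features
train, val, test = chronological_split(series)
train_z, val_z, test_z = standardize_with_train_stats(train, val, test)
```

Some codebases also let the validation and test segments start one lookback window earlier, so that their first forecast has a full history; that variation is omitted here.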
Owner
Phil Wang
Working with Attention. It's all we need
Ludwig is a toolbox that allows users to train and evaluate deep learning models without the need to write code.

Translated in 🇰🇷 Korean/ Ludwig is a toolbox that allows users to train and test deep learning models without the need to write code. It is built on

Ludwig 8.7k Jan 05, 2023
Block Sparse movement pruning

Movement Pruning: Adaptive Sparsity by Fine-Tuning Magnitude pruning is a widely used strategy for reducing model size in pure supervised learning; ho

Hugging Face 54 Dec 20, 2022
TDN: Temporal Difference Networks for Efficient Action Recognition

TDN: Temporal Difference Networks for Efficient Action Recognition Overview We release the PyTorch code of the TDN(Temporal Difference Networks).

Multimedia Computing Group, Nanjing University 326 Dec 13, 2022
A booklet on machine learning systems design with exercises

Machine Learning Systems Design Read this booklet here. This booklet covers four main steps of designing a machine learning system: Project setup Data

Chip Huyen 7.6k Jan 08, 2023
A modular, research-friendly framework for high-performance training and inference of sequence models at many scales

T5X T5X is a modular, composable, research-friendly framework for high-performance, configurable, self-service training, evaluation, and inference of

Google Research 1.1k Jan 08, 2023
Lane follower: Lane-detector (OpenCV) + Object-detector (YOLO5) + CAN-bus

Lane Follower This code is for the lane follower, including perception and control, as shown below. Environment Hardware Industrial Camera Intel-NUC(1

Siqi Fan 3 Jul 07, 2022
Teaching end to end workflow of deep learning

Deep-Education This repository is now available for public use for teaching end to end workflow of deep learning. This implies that learners/researche

Data Lab at College of William and Mary 2 Sep 26, 2022
Repository for "Improving evidential deep learning via multi-task learning," published in AAAI2022

Improving evidential deep learning via multi task learning It is a repository of AAAI2022 paper, “Improving evidential deep learning via multi-task le

deargen 11 Nov 19, 2022
Code for "Layered Neural Rendering for Retiming People in Video."

Layered Neural Rendering in PyTorch This repository contains training code for the examples in the SIGGRAPH Asia 2020 paper "Layered Neural Rendering

Google 154 Dec 16, 2022
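A multimodal content understanding algorithm framework, containing modules for data processing, pre-trained models, common models, and model acceleration.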

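Overview Architecture Design Plugin Introduction Installation and Usage Framework Introduction An easy-to-use unified training framework supporting multimodal and multi-task learning. Capability list: bert + classification tasks; custom task training (plugin registration). Framework design: the framework organizes the model training flow in layers. The DATA layer reads user data and manages it by field. The Parser layer converts the raw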

Tencent 265 Dec 22, 2022
Self-attentive task GAN for space domain awareness data augmentation.

SATGAN TODO: update the article URL once published. Article about this implemention The self-attentive task generative adversarial network (SATGAN) le

Nathan 2 Mar 24, 2022
Code for 2021 NeurIPS --- Towards Multi-Grained Explainability for Graph Neural Networks

ReFine: Multi-Grained Explainability for GNNs This is the official code for Towards Multi-Grained Explainability for Graph Neural Networks (NeurIPS 20

Shirley (Ying-Xin) Wu 47 Dec 16, 2022
A FAIR dataset of TCV experimental results for validating edge/divertor turbulence models.

TCV-X21 validation for divertor turbulence simulations Quick links Intro Welcome to TCV-X21. We're glad you've found us! This repository is designed t

0 Dec 18, 2021
AI virtual gym is an AI program which can be used to exercise and can be used to see if we are doing the exercises

AI virtual gym is an AI program which can be used to exercise and can be used to see if we are doing the exercises

4 Feb 13, 2022
Planning from Pixels in Environments with Combinatorially Hard Search Spaces -- NeurIPS 2021

PPGS: Planning from Pixels in Environments with Combinatorially Hard Search Spaces Environment Setup We recommend pipenv for creating and managing vir

Autonomous Learning Group 11 Jun 26, 2022
quantize aware training package for NCNN on pytorch

ncnnqat ncnnqat is a quantize aware training package for NCNN on pytorch. Table of Contents ncnnqat Table of Contents Installation Usage Code Examples

62 Nov 23, 2022
Implementation of ICCV2021(Oral) paper - VMNet: Voxel-Mesh Network for Geodesic-aware 3D Semantic Segmentation

VMNet: Voxel-Mesh Network for Geodesic-Aware 3D Semantic Segmentation Created by Zeyu HU Introduction This work is based on our paper VMNet: Voxel-Mes

HU Zeyu 82 Dec 27, 2022
(CVPR2021) Kaleido-BERT: Vision-Language Pre-training on Fashion Domain

Kaleido-BERT: Vision-Language Pre-training on Fashion Domain Mingchen Zhuge*, Dehong Gao*, Deng-Ping Fan#, Linbo Jin, Ben Chen, Haoming Zhou, Minghui

250 Jan 08, 2023
PyTorch Implementation of Unsupervised Depth Completion with Calibrated Backprojection Layers (ORAL, ICCV 2021)

Unsupervised Depth Completion with Calibrated Backprojection Layers PyTorch implementation of Unsupervised Depth Completion with Calibrated Backprojec

80 Dec 13, 2022
Cross-platform CLI tool to generate your Github profile's stats and summary.

ghs Cross-platform CLI tool to generate your Github profile's stats and summary. Preview Hop on to examples for other usecases. Jump to: Installation

HackerRank 134 Dec 20, 2022