TimeSHAP explains Recurrent Neural Network predictions.

Overview

TimeSHAP

TimeSHAP is a model-agnostic, recurrent explainer that builds upon KernelSHAP and extends it to the sequential domain. TimeSHAP computes event/timestamp-, feature-, and cell-level attributions. As sequences can be arbitrarily long, TimeSHAP also implements a pruning algorithm based on Shapley values that finds a subset of consecutive, recent events that contribute the most to the decision.

This repository contains the code implementation of the TimeSHAP algorithm presented in the paper TimeSHAP: Explaining Recurrent Models through Sequence Perturbations, published at KDD 2021.

Links to the paper here, and to the video presentation here.

Install TimeSHAP

Clone the repository into a local directory using:

git clone https://github.com/feedzai/timeshap.git

Install the package using pip:

pip install timeshap

To test your installation, start a Python session in your terminal using

python

And import TimeSHAP

import timeshap

TimeSHAP in 30 seconds

Inputs

  • Model being explained;
  • Instance(s) to explain;
  • Background instance.

Outputs

  • Local pruning output; (explaining a single instance)
  • Local event explanations; (explaining a single instance)
  • Local feature explanations; (explaining a single instance)
  • Global pruning statistics; (explaining multiple instances)
  • Global event explanations; (explaining multiple instances)
  • Global feature explanations; (explaining multiple instances)

Model Interface

In order for TimeSHAP to explain a model, an entry point must be provided. This Callable entry point must receive a 3-D numpy array of shape (#sequences, #sequence length, #features) and return a 2-D numpy array of shape (#sequences, 1) with the corresponding score of each sequence. Optionally, the entry point can also return the hidden state of the model together with the score (if applicable), which allows TimeSHAP to run more efficiently.

TimeSHAP is able to explain any black-box model as long as it complies with the previously described interface, including both PyTorch and TensorFlow models, both exemplified in our tutorials (PyTorch, TensorFlow).
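As a minimal sketch of the interface described above, the snippet below builds an entry point around a toy scoring function. The helper `make_entry_point` and the toy logistic "model" are hypothetical and only illustrate the expected input and output shapes; they are not part of TimeSHAP.

```python
import numpy as np


def make_entry_point(weights, bias=0.0):
    """Return a callable f(x) that follows the shape contract (hypothetical toy model)."""
    weights = np.asarray(weights, dtype=float)

    def f(x):
        x = np.asarray(x, dtype=float)
        if x.ndim != 3:
            raise ValueError("expected shape (n_sequences, seq_len, n_features)")
        # Toy "model": logistic score of the time-averaged features.
        pooled = x.mean(axis=1)               # (n_sequences, n_features)
        logits = pooled @ weights + bias      # (n_sequences,)
        scores = 1.0 / (1.0 + np.exp(-logits))
        return scores.reshape(-1, 1)          # (n_sequences, 1)

    return f


f = make_entry_point(weights=[0.5, -0.25, 1.0])
batch = np.zeros((4, 10, 3))                  # 4 sequences, 10 events, 3 features
print(f(batch).shape)
```

Any real model can be exposed the same way, as long as the wrapped callable accepts the 3-D array and returns one score per sequence.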

Example provided in our tutorials:

  • TensorFlow
model = tf.keras.models.Model(inputs=inputs, outputs=ff2)
f = lambda x: model.predict(x)
  • PyTorch (example where the model receives and returns hidden states)
model_wrapped = TorchModelWrapper(model)
f_hs = lambda x, y=None: model_wrapped.predict_last_hs(x, y)

Model Wrappers

In order to facilitate the interface between models and TimeSHAP, TimeSHAP implements ModelWrappers. These wrappers, used in the PyTorch tutorial notebook, allow for greater flexibility in the models being explained, as they provide:

  • Batching logic: useful when very large inputs or NSamples cannot fit in GPU memory and batching mechanisms are therefore required;
  • Input format/type: useful when your model does not work with numpy arrays. This is the case in our provided PyTorch example.
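The batching idea can be sketched as follows. The class `BatchedModelWrapper` and the function `toy_predict` are hypothetical names used for illustration only; this is not TimeSHAP's ModelWrapper implementation, just one way to split a large input into chunks and cast it to the type a model expects.

```python
import numpy as np


class BatchedModelWrapper:
    """Hypothetical wrapper: casts the input and scores it in fixed-size batches."""

    def __init__(self, predict_fn, batch_size=256):
        if batch_size < 1:
            raise ValueError("batch_size must be at least 1")
        self.predict_fn = predict_fn
        self.batch_size = batch_size

    def predict(self, x):
        # Input format/type handling: here we simply cast to float32.
        x = np.asarray(x, dtype=np.float32)
        if x.shape[0] == 0:
            return np.empty((0, 1), dtype=np.float32)
        chunks = []
        # Batching logic: score at most `batch_size` sequences per call.
        for start in range(0, x.shape[0], self.batch_size):
            scores = self.predict_fn(x[start:start + self.batch_size])
            chunks.append(np.asarray(scores).reshape(-1, 1))
        return np.concatenate(chunks, axis=0)


batch_sizes_seen = []


def toy_predict(x):
    """Stand-in for a real model call; records how many sequences it received."""
    batch_sizes_seen.append(x.shape[0])
    return x.sum(axis=(1, 2))


wrapper = BatchedModelWrapper(toy_predict, batch_size=4)
f = lambda x: wrapper.predict(x)
scores = f(np.ones((10, 5, 3)))
print(scores.shape, batch_sizes_seen)
```

In this sketch the ten input sequences reach the model in chunks of 4, 4 and 2, while the caller still receives a single (#sequences, 1) array.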

TimeSHAP Explanation Methods

TimeSHAP offers several methods to use depending on the desired explanations. Local methods provide a detailed view of a model decision corresponding to a specific sequence being explained. Global methods aggregate local explanations of a given dataset to present a global view of the model.

Local Explanations

Pruning

local_pruning() performs the pruning algorithm on a given sequence with a user-defined tolerance and returns the pruning index along with the information for plotting.

plot_temp_coalition_pruning() plots the pruning algorithm information calculated by local_pruning().
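To make the idea behind the pruning step concrete, here is a minimal sketch under stated assumptions; it is not the code behind local_pruning(). It treats the older events and the recent events of a sequence as two players, computes the exact two-player Shapley value of the older group, and moves the split point backwards until that value falls below the tolerance. The names `coalition_score`, `find_pruning_index` and `toy_f`, the zero baseline event, and the returned index convention (position of the first kept event) are all hypothetical.

```python
import numpy as np


def coalition_score(f, x, baseline, split, keep_old, keep_recent):
    """Score x after replacing the groups that are not kept with the baseline event."""
    perturbed = np.broadcast_to(baseline, x.shape).copy()
    if keep_old:
        perturbed[:, :split] = x[:, :split]
    if keep_recent:
        perturbed[:, split:] = x[:, split:]
    return float(f(perturbed)[0, 0])


def find_pruning_index(f, x, baseline, tol):
    """Return the index of the first event kept; events before it are pruned."""
    seq_len = x.shape[1]
    f_none = coalition_score(f, x, baseline, 0, False, False)
    f_all = float(f(x)[0, 0])
    # Start with only the last event in the recent group and grow it backwards.
    for split in range(seq_len - 1, -1, -1):
        f_old = coalition_score(f, x, baseline, split, True, False)
        f_recent = coalition_score(f, x, baseline, split, False, True)
        # Exact Shapley value of the "older events" player in a two-player game.
        phi_old = 0.5 * ((f_old - f_none) + (f_all - f_recent))
        if abs(phi_old) < tol:
            return split
    return 0


seq_len = 8
decay = 0.5 ** np.arange(seq_len - 1, -1, -1)  # older events get smaller weights


def toy_f(x):
    """Toy model: weighted sum of the first feature, dominated by recent events."""
    return (x[:, :, 0] * decay).sum(axis=1).reshape(-1, 1)


x = np.ones((1, seq_len, 2))
baseline = np.zeros(2)
pruning_index = find_pruning_index(toy_f, x, baseline, tol=0.1)
print(pruning_index)  # 3
```

With this toy model and a tolerance of 0.1, the three oldest events are grouped and pruned, and only the five most recent events would need individual explanations. A smaller tolerance keeps more events.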

Event level explanations

local_event() calculates event-level explanations of a given sequence with the user-given parameters and returns the respective event-level explanations.

plot_event_heatmap() plots the event-level explanations calculated by local_event().
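Event-level attribution can be illustrated with exact Shapley values on a very short sequence. TimeSHAP builds on KernelSHAP, which approximates these values by sampling coalitions; the sketch below instead enumerates every coalition of events, which is only feasible for a handful of events. The function `exact_event_shapley`, the linear `toy_f` model and the zero baseline are hypothetical and are not part of the TimeSHAP API.

```python
import itertools
import math

import numpy as np


def exact_event_shapley(f, x, baseline):
    """Exact Shapley value of each event of one sequence x with shape (1, seq_len, n_features)."""
    seq_len = x.shape[1]
    background = np.broadcast_to(baseline, x.shape)

    def value(kept):
        # Events in `kept` keep their real values; all others take the baseline.
        perturbed = background.copy()
        idx = list(kept)
        if idx:
            perturbed[:, idx] = x[:, idx]
        return float(f(perturbed)[0, 0])

    phi = np.zeros(seq_len)
    for i in range(seq_len):
        others = [t for t in range(seq_len) if t != i]
        for size in range(seq_len):
            # Standard Shapley weight for coalitions of this size.
            weight = (math.factorial(size) * math.factorial(seq_len - size - 1)
                      / math.factorial(seq_len))
            for subset in itertools.combinations(others, size):
                phi[i] += weight * (value(subset + (i,)) - value(subset))
    return phi


event_weights = np.array([0.1, 0.2, 0.3, 0.4])


def toy_f(x):
    """Toy model: weighted sum of the first feature over events."""
    return (x[:, :, 0] * event_weights).sum(axis=1).reshape(-1, 1)


x = np.zeros((1, 4, 2))
x[0, :, 0] = [1.0, 2.0, 3.0, 4.0]
baseline = np.zeros(2)
phi = exact_event_shapley(toy_f, x, baseline)
print(phi)
```

For this linear toy model each event's attribution equals its weight times its value, and the attributions sum to the difference between the score of the sequence and the score of the all-baseline sequence.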

Feature level explanations

local_feat() calculates feature-level explanations of a given sequence with the user-given parameters and returns the respective feature-level explanations.

plot_feat_barplot() plots the feature-level explanations calculated by local_feat().

Cell level explanations

local_cell_level() calculates cell-level explanations of a given sequence with the respective event- and feature-level explanations and user-given parameters, returning the respective cell-level explanations.

plot_cell_level() plots the cell-level explanations calculated by local_cell_level().

Local Report

local_report() calculates TimeSHAP local explanations for a given sequence and plots them.

Global Explanations

Global pruning statistics

prune_all() performs the pruning algorithm on multiple given sequences.

pruning_statistics() calculates the pruning statistics for several user-given pruning tolerances using the pruning data calculated by prune_all(), returning a pandas.DataFrame with the statistics.

Global event level explanations

event_explain_all() calculates TimeSHAP event-level explanations for multiple instances given user-defined parameters.

plot_global_event() plots the global event-level explanations calculated by event_explain_all().

Global feature level explanations

feat_explain_all() calculates TimeSHAP feature-level explanations for multiple instances given user-defined parameters.

plot_global_feat() plots the global feature-level explanations calculated by feat_explain_all().

Global report

global_report() calculates TimeSHAP explanations for multiple instances, aggregating the explanations on two plots and returning them.

Tutorial

To see TimeSHAP interfaces and methods in action, you can consult AReM.ipynb. In this tutorial we get an open-source dataset, process it, train a PyTorch recurrent model with it, and use TimeSHAP to explain it, showcasing all previously described methods.

Additionally, we also train a TensorFlow model on the same dataset in AReM_TF.ipynb.

Repository Structure

Citing TimeSHAP

@inproceedings{bento2021timeshap,
    author = {Bento, Jo\~{a}o and Saleiro, Pedro and Cruz, Andr\'{e} F. and Figueiredo, M\'{a}rio A.T. and Bizarro, Pedro},
    title = {TimeSHAP: Explaining Recurrent Models through Sequence Perturbations},
    year = {2021},
    isbn = {9781450383325},
    publisher = {Association for Computing Machinery},
    address = {New York, NY, USA},
    url = {https://doi.org/10.1145/3447548.3467166},
    doi = {10.1145/3447548.3467166},
    booktitle = {Proceedings of the 27th ACM SIGKDD Conference on Knowledge Discovery & Data Mining},
    pages = {2565--2573},
    numpages = {9},
    keywords = {SHAP, Shapley values, TimeSHAP, XAI, RNN, explainability},
    location = {Virtual Event, Singapore},
    series = {KDD '21}
}
Comments
  • Error in running the example notebook (AReM_TF)

    Nice work! I have been trying to run one of the tutorial notebooks in the repository (i.e., AReM_TF), but I faced an error. The notebook chunk that produces the error is:

    from timeshap.explainer import local_report, local_pruning
    
    pruning_dict = {'tol': 0.025}
    event_dict = {'rs': 42, 'nsamples': 320}
    feature_dict = {'rs': 42, 'nsamples': 320, 'feature_names': model_features, 'plot_features': plot_feats}
    cell_dict = {'rs': 42, 'nsamples': 320, 'top_x_feats': 2, 'top_x_events': 2}
    local_report(f, pos_x_data, pruning_dict, event_dict, feature_dict, cell_dict=cell_dict, entity_uuid=positive_sequence_id, entity_col='all_id', baseline=average_event)
    

    and the produced error is as follows: [image] [image]

    It would be great if you could help me with this error.

    opened by aminnayebi 7
  • Unable to install timeshap package

    Hello @feedzaiadmin , @saleiro ,

    I am unable to run command "pip install timeshap" on my ubuntu system. It throws me below error:

    ERROR: Could not find a version that satisfies the requirement timeshap (from versions: none)
    ERROR: No matching distribution found for timeshap

    Seems to me a compatibility issue. I have tried with python versions - 3.6,3.8. Is there any specific version which supports it? Will be waiting for response.

    Thanks

    opened by vishants98 5
  • How to adapt the transformation function to account for variable sequence length?

    I am trying to use TimeSHAP on my use case. Per my understanding, in the AReM example, the way you transform the data using the df_to_numpy function is to make a prediction for the last value of the sequence – see the screen below:

    [image] In the case of the AReM tutorial data, the predictions are based on the whole sequence: all rows (Row IDs 1-10) are used for sequence ID 1 (light blue) and the prediction is made for Timestamp 10 (dark blue; Row ID 10). Later, the light orange rows (Row IDs 11-20) are used to predict the label marked in dark orange (Row ID 20).

    In my use case, the model predicts on a rolling-window basis and I would need predictions for every row (not only for a sequence). See the screen and explanation below. [image] Let's say my rolling window is 6 and Row IDs 1-6 (light green) are used to predict Row ID 7 (dark green); later, Row IDs 2-7 (light grey) are used to predict Row ID 8 (dark grey), etc. When a new sequence starts, we repeat the process, so we take Row IDs 11-16 and predict Row ID 17, etc. For my use case, it's important to evaluate the predictions for every Row ID, not only for the whole sequence.

    The problem I am facing is that when I try to run the function get_avg_score_with_avg_event on the data defined as in the picture above, I get the following error: [image]

    The way my data is transformed from 2D into 3D format is defined by the function below: [image]

    My question is whether it’s possible to make TimeSHAP work for data which is transformed in the way described in my use case. When I use the transformation defined in your function df_to_numpy, I do not get an error; however, it is not adapted to my use case.

    opened by grzechowiak 4
  • Issue in reproducing TimeSHAP Tutorial - TensorFlow - AReM dataset

    Foremost, thanks for the library, really great job!

    I am trying to reproduce your TimeSHAP Tutorial for TF and I am having an issue in the section for Global Explanations when running the global_report() function - screen below:

    [image]

    The error refers to encoding the \u2264 character which is a sign: <=. I was trying to solve that myself by modifying the pruning.py function according to the error by adding encoding="utf-8" in line 326 with open(file_path, 'a', newline='') as file:, however it didn't solve the problem. Any advice is very welcome!

    Also, for completeness, I want to mention that I had a problem loading the data; see the error screens below. I was able to solve the problem only by deleting 2 datasets, cycling/dataset9.csv and cycling/dataset14.csv, after which the rest of the code worked.

    [image 1/2] [image 2/2]

    opened by grzechowiak 4
  • Timeshap for regression

    I am working on a time series forecasting problem using LSTM. Can I use timeshap for such a regression problem? Do you by chance have a demo for regression?

    Thanks

    opened by mgorjis 3
  • Plot Coalition Pruning is not working

    Hi, I am doing a project as part of a Master's thesis. I was testing your introductory notebook with the TensorFlow implementation, but it does not work because of an error raised by one of your plot functions. The funny thing is that your other notebook, i.e., the one with the torch model, works fine. I managed to find the issue: it is raised by an inner method called solve_negatives_method in src/timeshap/plot/pruning.py, which raises an InvalidKeyError caused by this specific row:

        df.at[corresponding_row.index, 'Shapley Value'] = corresponding_row['Shapley Value'].values[0] + row['Shapley Value']
    

    As far as I understand, we are passing a list-like index, usually made of only one value, to the pandas.DataFrame.at accessor, which requires a single row label and a column identifier. As such, one solution that works now could be:

        df.at[corresponding_row.index[0], 'Shapley Value'] = corresponding_row['Shapley Value'].values[0] + row['Shapley Value']
    

    I want to thank you very much for the work done, and I will be happy to discuss further results of my thesis with you. :) Hoping for the best and looking forward to hearing from you, Eric

    opened by Erhtric 2
  • TimeSHAP for text?

    I'm working with a 1-layer GRU for text classification that takes BERT embeddings at the input. Each input sequence is of the shape (sequence length, bert-embedding-dimension). I'm looking for word level attribution scores for each sequence's prediction. Currently with the captum integrated gradients and occlusion explainers, I get attribution scores that are almost always the last few words of each sequence. This seems like it's stemming from the directional processing of GRU - any thoughts?

    Do you think TimeSHAP would be applicable for my use case? I suppose I could consider each word as an event and each embedding dimension as a feature, then I could use the event level local explanations from the library? However, note that in my case, the events (i.e words) from the beginning of the sequence could be more important than those at the end of the sequence (i.e. most recent ones) - this violates the assumptions you use for your approximation (i.e pruning), so perhaps it's not applicable to text?

    opened by itsmemala 2
  • Intuition around pruning/baseline selection

    While calculating a global report, I'm currently running into errors that "Score difference between baseline and instance is too low < 0.1...Consider choosing another baseline." My baselines have been the average_event and average_sequence. I also notice that occasionally the values in the error change with different pruning tolerance, but not in a consistent way. Do you have advice for dealing with this? Thanks.

    opened by xydisla 2
  • CNN model

    I currently have a CNN model, and previously had to do some strange hacking to get time series importance values. Your package now shines a better light on this issue. For multivariate forecasts CNNs still fare well, sometimes better than RNNs. From what I can see, there is no reason why CNNs wouldn't work with your software. Let me know if I have that wrong.

    opened by firmai 2
  • How to speed up TIMESHAP computation

    Hi all!

    The package itself is really interesting and intuitive to use. However, I want to speed up the TimeSHAP computation. Can I use a GPU to calculate Shapley values? I used the device parameter in TorchModelWrapper, but the efficiency of the GPU is too low to accelerate the TimeSHAP computation. Any suggestion would be appreciated.

    opened by Changshu135 2
  • Error when executing local_report on TF example: InvalidIndexError: Int64Index([1], dtype='int64')

    Hi all! When executing your TF example notebook without any changes, it fails at this line: local_report(f, pos_x_data, pruning_dict, event_dict, feature_dict, cell_dict=cell_dict, entity_uuid=positive_sequence_id, entity_col='all_id', baseline=average_event)

    with the following error: InvalidIndexError: Int64Index([1], dtype='int64')

    Stack trace:

    TypeError                                 Traceback (most recent call last)
    File ~/opt/anaconda3/envs/timeshap/lib/python3.10/site-packages/pandas/core/indexes/base.py:3621, in Index.get_loc(self, key, method, tolerance)
       3620 try:
    -> 3621     return self._engine.get_loc(casted_key)
       3622 except KeyError as err:
    
    File pandas/_libs/index.pyx:136, in pandas._libs.index.IndexEngine.get_loc()
    
    File pandas/_libs/index.pyx:142, in pandas._libs.index.IndexEngine.get_loc()
    
    TypeError: 'Int64Index([1], dtype='int64')' is an invalid key
    
    During handling of the above exception, another exception occurred:
    
    InvalidIndexError                         Traceback (most recent call last)
    Input In [27], in <cell line: 7>()
          5 feature_dict = {'rs': 42, 'nsamples': 32000, 'feature_names': model_features, 'plot_features': plot_feats}
          6 cell_dict = {'rs': 42, 'nsamples': 32000, 'top_x_feats': 2, 'top_x_events': 2}
    ----> 7 local_report(f, pos_x_data, pruning_dict, event_dict, feature_dict, cell_dict=cell_dict, entity_uuid=positive_sequence_id, entity_col='all_id', baseline=average_event)
    
    File ~/temp/timeshap/src/timeshap/explainer/local_methods.py:139, in local_report(f, data, pruning_dict, event_dict, feature_dict, cell_dict, entity_uuid, entity_col, time_col, model_features, baseline, verbose)
        137 pruning_idx = data.shape[1] + coal_prun_idx
        138 plot_lim = max(abs(coal_prun_idx)+10, 40)
    --> 139 pruning_plot = plot_temp_coalition_pruning(coal_plot_data, coal_prun_idx, plot_lim)
        141 event_data = local_event(f, data, event_dict, entity_uuid, entity_col, baseline, pruning_idx)
        142 event_plot = plot_event_heatmap(event_data)
    
    File ~/temp/timeshap/src/timeshap/plot/pruning.py:53, in plot_temp_coalition_pruning(df, pruned_idx, plot_limit, solve_negatives)
         51 df = df[df['t (event index)'] >= -plot_limit]
         52 if solve_negatives:
    ---> 53     df = solve_negatives_method(df)
         55 base = (alt.Chart(df).encode(
         56     x=alt.X("t (event index)", axis=alt.Axis(title='t (event index)', labelFontSize=15,
         57                           titleFontSize=15)),
       (...)
         70 )
         71 )
         73 area_chart = base.mark_area(opacity=0.5)
    
    File ~/temp/timeshap/src/timeshap/plot/pruning.py:47, in plot_temp_coalition_pruning.<locals>.solve_negatives_method(df)
         45 for idx, row in negative_values.iterrows():
         46     corresponding_row = df[np.logical_and(df['t (event index)'] == row['t (event index)'], ~(df['Coalition'] == row['Coalition']))]
    ---> 47     df.at[corresponding_row.index, 'Shapley Value'] = corresponding_row['Shapley Value'].values[0] + row['Shapley Value']
         48     df.at[idx, 'Shapley Value'] = 0
         49 return df
    
    File ~/opt/anaconda3/envs/timeshap/lib/python3.10/site-packages/pandas/core/indexing.py:2281, in _AtIndexer.__setitem__(self, key, value)
       2278     self.obj.loc[key] = value
       2279     return
    -> 2281 return super().__setitem__(key, value)
    
    File ~/opt/anaconda3/envs/timeshap/lib/python3.10/site-packages/pandas/core/indexing.py:2236, in _ScalarAccessIndexer.__setitem__(self, key, value)
       2233 if len(key) != self.ndim:
       2234     raise ValueError("Not enough indexers for scalar access (setting)!")
    -> 2236 self.obj._set_value(*key, value=value, takeable=self._takeable)
    
    File ~/opt/anaconda3/envs/timeshap/lib/python3.10/site-packages/pandas/core/frame.py:3869, in DataFrame._set_value(self, index, col, value, takeable)
       3867 else:
       3868     series = self._get_item_cache(col)
    -> 3869     loc = self.index.get_loc(index)
       3871 # setitem_inplace will do validation that may raise TypeError
       3872 #  or ValueError
       3873 series._mgr.setitem_inplace(loc, value)
    
    File ~/opt/anaconda3/envs/timeshap/lib/python3.10/site-packages/pandas/core/indexes/base.py:3628, in Index.get_loc(self, key, method, tolerance)
       3623         raise KeyError(key) from err
       3624     except TypeError:
       3625         # If we have a listlike key, _check_indexing_error will raise
       3626         #  InvalidIndexError. Otherwise we fall through and re-raise
       3627         #  the TypeError.
    -> 3628         self._check_indexing_error(key)
       3629         raise
       3631 # GH#42269
    
    File ~/opt/anaconda3/envs/timeshap/lib/python3.10/site-packages/pandas/core/indexes/base.py:5637, in Index._check_indexing_error(self, key)
       5633 def _check_indexing_error(self, key):
       5634     if not is_scalar(key):
       5635         # if key is not a scalar, directly raise an error (the code below
       5636         # would convert to numpy arrays and raise later any way) - GH29926
    -> 5637         raise InvalidIndexError(key)
    
    InvalidIndexError: Int64Index([1], dtype='int64')
    
    opened by JulianKlug 2