Training a Resilient Q-Network against Observational Interference, Causal Inference Q-Networks

Overview

Obs-Causal-Q-Network

AAAI 2022 - Training a Resilient Q-Network against Observational Interference

Preprint | Slides | Colab Demo | PyTorch

Environment Setup

  • Option 1 (from the conda .yml, under CUDA 10.2 and Python 3.6)
conda env create -f obs-causal-q-conda.yml 
  • Option 2 (from a clean Python 3.6 environment; for Banana Navigator, also follow the setup of the Unity ML-Agents 3D environment)
pip install torch torchvision torchaudio
pip install dowhy
pip install gym
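A quick way to confirm that option 2 produced a usable environment is to check that each installed package can be located before launching the training scripts. This is a minimal sketch, not part of the repository: the module list simply mirrors the pip commands above, and the Unity dependency for Banana Navigator is left out because its package name depends on how that environment was set up.

```python
import importlib.util
import sys

# Top-level modules installed by the option 2 pip commands above
REQUIRED = ("torch", "torchvision", "torchaudio", "dowhy", "gym")


def missing_modules(names=REQUIRED):
    """Return the names that cannot be located on sys.path (top-level names are not imported)."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def report(names=REQUIRED):
    """Print the interpreter version and any missing modules; return True if none are missing."""
    print("Python", sys.version.split()[0], "(the setup notes target Python 3.6)")
    missing = missing_modules(names)
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All required modules were found.")
    return not missing


if __name__ == "__main__":
    report()
```

Running it as a script prints the interpreter version and lists any module that still needs installing.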

1. Example of Training a Causal Inference Q-Network (CIQ) on CartPole

  • Run Causal Inference Q-Network Training (--network 1 for Treatment Inference Q-network)
python 0-cartpole-main.py --network 1
  • Causal Inference Q-Network Architecture

  • Output Logs
observation space: Box(4,)
action space: Discrete(2)
Timing Atk Ratio: 10%
Using CEQNetwork_1. Number of Params: 41872
 Interference Type: 1  Use baseline:  0 use CGM:  1
With:  10.42 % timing attack
Episode 0   Score: 48.00, Average Score: 48.00, Loss: 1.71
With:  0.0 % timing attack
Episode 20   Score: 15.00, Average Score: 18.71, Loss: 30.56
With:  3.57 % timing attack
Episode 40   Score: 28.00, Average Score: 19.83, Loss: 36.36
With:  8.5 % timing attack
Episode 60   Score: 200.00, Average Score: 43.65, Loss: 263.29
With:  9.0 % timing attack
Episode 80   Score: 200.00, Average Score: 103.53, Loss: 116.35
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 193.4
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 164.2
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 147.8
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 193.4
With:  9.5 % timing attack
Episode 100   Score: 200.00, Average Score: 163.20, Loss: 77.38
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 198.4
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 200.0
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 200.0
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 197.8
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 200.0
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 200.0
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 200.0
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 197.6
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 198.6
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 200.0
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 200.0
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 199.0
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 200.0
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 186.8
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 200.0

Environment solved in 114 episodes!     Average Score: 195.55
Environment solved in 114 episodes!     Average Score: 195.55 +- 25.07
############# Basic Evaluate #############
Using CEQNetwork_1. Number of Params: 41872
Evaluate Score : 200.0
############# Noise Evaluate #############
Using CEQNetwork_1. Number of Params: 41872
Robust Score : 200.0
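The logs above report a target interference rate ("Timing Atk Ratio: 10%") and a realized per-episode rate ("With: 10.42 % timing attack"). One way such numbers can arise is to perturb each observation independently with a fixed probability and keep a 0/1 interference label per step, which would play the role of the treatment in a Treatment Inference Q-network. The following is a hypothetical sketch under that assumption: the class name `TimingInterference`, the independent per-step selection, and the Gaussian noise are illustrative choices, not the repository's implementation.

```python
import random


class TimingInterference:
    """Hypothetical sketch: perturb each observation independently with probability `ratio`.

    Gaussian noise is an illustrative stand-in for the interference; the returned
    0/1 label marks which steps were perturbed.
    """

    def __init__(self, ratio=0.10, noise_std=1.0, seed=None):
        self.ratio = ratio
        self.noise_std = noise_std
        self.rng = random.Random(seed)  # private RNG so runs are reproducible with a seed
        self.labels = []  # per-step interference labels for the current episode

    def reset(self):
        """Forget the labels of the previous episode."""
        self.labels = []

    def __call__(self, obs):
        """Return (observation, label); label is 1 if noise was added at this step."""
        attacked = self.rng.random() < self.ratio
        self.labels.append(int(attacked))
        if attacked:
            # Works for lists or 1-D NumPy arrays; the result is a plain list
            obs = [x + self.rng.gauss(0.0, self.noise_std) for x in obs]
        return obs, int(attacked)

    def realized_percent(self):
        """Percentage of steps perturbed so far in this episode, rounded to two decimals."""
        if not self.labels:
            return 0.0
        return round(100.0 * sum(self.labels) / len(self.labels), 2)


if __name__ == "__main__":
    attack = TimingInterference(ratio=0.10, seed=0)
    for _ in range(200):  # 200 steps, CartPole-v0's episode cap
        obs, label = attack([0.0, 0.0, 0.0, 0.0])  # placeholder 4-dim observation
    print("With: ", attack.realized_percent(), "% timing attack")
```

Because the draws are independent, the realized percentage fluctuates around the target from episode to episode, much like the 0.0 % to 10.42 % values printed in the CartPole logs.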

2. Example of Training a "Variational" Causal Inference Q-Network on Unity 3D Banana Navigator

  • Run Variational Causal Inference Q-Networks (VCIQs) Training (--network 3 for Causal Variational Inference)
python 1-banana-navigator-main.py --network 3
  • Variational Causal Inference Q-Network Architecture

  • Output Logs
'Academy' started successfully!
Unity Academy name: Academy
        Number of Brains: 1
        Number of External Brains : 1
        Lesson number : 0
        Reset Parameters :

Unity brain name: BananaBrain
        Number of Visual Observations (per agent): 0
        Vector Observation space type: continuous
        Vector Observation space size (per agent): 37
        Number of stacked Vector Observation: 1
        Vector Action space type: discrete
        Vector Action space size (per agent): 4
        Vector Action descriptions: , , , 
Timing Atk Ratio: 10%
Using CEVAE_QNetwork.
Unity Worker id: 10  T: 1  Use baseline:  0  CEVAE:  1
With:  9.67 % timing attack
Episode 0   Score: 0.00, Average Score: 0.00
With:  11.0 % timing attack
Episode 5   Score: 1.00, Average Score: 0.17
With:  11.33 % timing attack
Episode 10   Score: 0.00, Average Score: 0.36
With:  10.33 % timing attack
Episode 15   Score: 0.00, Average Score: 0.56
...
Episode 205   Score: 10.00, Average Score: 9.25
With:  9.33 % timing attack
Episode 210   Score: 9.00, Average Score: 9.70
With:  9.0 % timing attack
Episode 215   Score: 10.00, Average Score: 11.10
With:  8.33 % timing attack
Episode 220   Score: 14.00, Average Score: 10.85
With:  12.33 % timing attack
Episode 225   Score: 19.00, Average Score: 11.70
With:  11.0 % timing attack
Episode 230   Score: 18.00, Average Score: 12.10
With:  7.67 % timing attack
Episode 235   Score: 21.00, Average Score: 11.60
With:  9.67 % timing attack
Episode 240   Score: 16.00, Average Score: 12.05

Environment solved in 242 episodes!     Average Score: 12.50
Environment solved in 242 episodes!     Average Score: 12.50 +- 4.87
############# Basic Evaluate #############
Using CEVAE_QNetwork.
Evaluate Score : 12.6
############# Noise Evaluate #############
Using CEVAE_QNetwork.
Robust Score : 12.5
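Both runs end with an "Environment solved in N episodes! Average Score: X +- Y" line. A rolling-window check is a common way to produce such a report; the hypothetical `SolvedTracker` below shows one version, assuming the average is compared against a fixed target over the most recent `window` episodes and that the "+-" value is the population standard deviation (what NumPy's `np.std` returns by default). The actual window and targets used by `0-cartpole-main.py` and `1-banana-navigator-main.py` are not stated in this README, so both are parameters here.

```python
import statistics
from collections import deque


class SolvedTracker:
    """Hypothetical rolling-window 'solved' check (window and target are assumptions)."""

    def __init__(self, target, window=100):
        self.target = target
        # deque(maxlen=...) silently drops the oldest score once the window is full
        self.scores = deque(maxlen=window)

    def add(self, score):
        """Record one episode score; return True once a full window averages >= target."""
        self.scores.append(score)
        return len(self.scores) == self.scores.maxlen and self.mean() >= self.target

    def mean(self):
        """Mean of the scores currently in the window (needs at least one score)."""
        return statistics.mean(self.scores)

    def std(self):
        # Population standard deviation: the assumed meaning of the '+-' value
        return statistics.pstdev(self.scores)

    def report(self, episode):
        return "Environment solved in {} episodes!\tAverage Score: {:.2f} +- {:.2f}".format(
            episode, self.mean(), self.std())


if __name__ == "__main__":
    # Synthetic scores for illustration only; 195 over 100 episodes is the
    # conventional CartPole-v0 criterion, not necessarily the script's setting.
    tracker = SolvedTracker(target=195.0, window=100)
    for episode, score in enumerate([200.0] * 120, start=1):
        if tracker.add(score):
            print(tracker.report(episode))
            break
```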

Reference

This fun work was initiated when Danny and I first read about the Causal Variational Model between 2018 and 2019, with help from Dr. Yi Ouyang and Dr. Pin-Yu Chen.

Please consider citing the paper if you find this work helpful or relevant to your research.

@article{yang2021causal,
  title={Causal Inference Q-Network: Toward Resilient Reinforcement Learning},
  author={Yang, Chao-Han Huck and Hung, I-Te Danny and Ouyang, Yi and Chen, Pin-Yu},
  journal={arXiv preprint arXiv:2102.09677},
  year={2021}
}