Decision Transformer: A brand new Offline RL Pattern

Overview

DecisionTransformer_StepbyStep

Intro

This is a reproduction of Decision Transformer, a popular NeurIPS 2021 paper.

👍 Paper: Decision Transformer: Reinforcement Learning via Sequence Modeling

👍 Official Git repository: decision-transformer (official)

Decision Transformer

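Decision Transformer is an Offline RL method. Offline RL learns a policy for the agent from suboptimal data, that is, it aims to produce maximally effective behavior from a fixed, limited set of experience.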

👀️ Motivation

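DT treats RL as a sequence modeling problem. Instead of conventional RL methods, it uses a network that outputs actions directly to make decisions. Conventional RL methods have known problems: for example, bootstrapping while estimating future return can lead to overestimation, and they rely on the Markov assumption.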

DT draws on the strong representational capacity and temporal modeling ability of the Transformer.

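  • Decision Transformer matches or exceeds the performance of the best current mainstream methods based on dynamic programming;
  • On tasks that require long-term credit assignment (for example, sparse or delayed rewards), Decision Transformer far outperforms the best mainstream methods.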

🚀️ Core idea of DT

image.png

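The core idea of Decision Transformer: states, actions, and returns are fed into modality-specific linear embeddings, and a positional episodic timestep encoding carrying the timestep information is added. These tokens are fed into a GPT architecture, which uses a causal self-attention mask to predict actions.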
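
The token layout described above can be sketched in plain Python. This is a minimal illustration rather than this repository's code: the helper names `returns_to_go` and `interleave_tokens` are hypothetical, and the sketch assumes the "return" token is the undiscounted return-to-go (the sum of rewards from the current step to the end of the trajectory), as in the Decision Transformer paper.

```python
from typing import Any, List, Sequence, Tuple


def returns_to_go(rewards: Sequence[float]) -> List[float]:
    """Undiscounted return-to-go: rtg[t] is the sum of rewards from step t onward."""
    rtg = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running += rewards[t]
        rtg[t] = running
    return rtg


def interleave_tokens(
    rtg: Sequence[float], states: Sequence[Any], actions: Sequence[Any]
) -> List[Tuple[str, int, Any]]:
    """Order tokens as (R_0, s_0, a_0, R_1, s_1, a_1, ...).

    Each token keeps its episodic timestep so that a timestep embedding
    can be added to all three tokens belonging to the same step.
    """
    tokens: List[Tuple[str, int, Any]] = []
    for t, (r, s, a) in enumerate(zip(rtg, states, actions)):
        tokens.append(("return", t, r))
        tokens.append(("state", t, s))
        tokens.append(("action", t, a))
    return tokens


if __name__ == "__main__":
    rtg = returns_to_go([1.0, 0.0, 2.0])
    for token in interleave_tokens(rtg, ["s0", "s1", "s2"], ["a0", "a1", "a2"]):
        print(token)
```

With a context length of K steps, the sequence passed to the Transformer therefore contains 3K tokens.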

🎉️ Advantages of DT

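  1. No Markov assumption is required;
  2. No learned value function is used as the training target;
  3. By exploiting the properties of the Transformer, it bypasses the need for bootstrapping to perform long-term credit assignment, which avoids the "short-sighted" behavior of temporal-difference learning;
  4. Credit assignment can be performed directly through self-attention. In contrast to Bellman backups, which propagate rewards slowly and are prone to distractor signals, this lets the Transformer keep working effectively when rewards are sparse or distracting.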
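
To make the last point concrete, the sketch below builds a causal mask over the interleaved (return, state, action) tokens and lists which tokens are visible when the action for a given step is predicted. It is an illustration under assumptions: the helper names are hypothetical, and it assumes (following the paper) that the action a_t is read out at the position of the state token s_t.

```python
from typing import List


def causal_mask(n_tokens: int) -> List[List[int]]:
    """mask[i][j] == 1 means position i may attend to position j (only j <= i)."""
    return [[1 if j <= i else 0 for j in range(n_tokens)] for i in range(n_tokens)]


def token_labels(n_steps: int) -> List[str]:
    """Labels for the interleaved sequence R0, s0, a0, R1, s1, a1, ..."""
    labels: List[str] = []
    for t in range(n_steps):
        labels += [f"R{t}", f"s{t}", f"a{t}"]
    return labels


def visible_for_action(t: int, n_steps: int) -> List[str]:
    """Tokens the model can attend to when predicting action a_t.

    Assumption: a_t is predicted from the output at the s_t position,
    which sits at index 3 * t + 1 in the interleaved sequence.
    """
    labels = token_labels(n_steps)
    mask = causal_mask(len(labels))
    query = 3 * t + 1
    return [labels[j] for j, allowed in enumerate(mask[query]) if allowed]


if __name__ == "__main__":
    print(visible_for_action(1, n_steps=3))
```

Every earlier return, state, and action inside the context window is reachable in a single attention step, so no reward has to be propagated backward through repeated Bellman backups.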

Dependencies

1. D4RL (Datasets for Deep Data-Driven Reinforcement Learning)

2. MuJoCo 210

# Install absl-py and matplotlib before installing D4RL
pip install absl-py 
pip install matplotlib 

"""
git clone https://github.com/rail-berkeley/d4rl.git
cd d4rl
pip install -e . # this approach does not work!!
"""

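# First, git clone the repository at https://github.com/deepmind/dm_control
# cd into the cloned repository
pip install -r requirements.txt
# Then
pip install matplotlib
# Then https://github.com/takuseno/d3rlpy
pip install d3rlpy
# Then install MuJoCo 210
# Install it directly, then add the environment variables
# Once that is done, go into the d4rl folder
python setup.py install
# d4rl 1.1 installed successfully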

3. GPT-2


pip install transformers
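
As a rough sketch of how the `transformers` package could supply the GPT-2 backbone, the snippet below builds a `GPT2Config` whose sizes mirror the Config listed under Experiments (embed_dim, n_layer, n_head, dropout, activation_function). Whether this repository constructs its model exactly this way is an assumption; `vocab_size=1` is chosen here only because the inputs would be continuous embeddings rather than text tokens.

```python
from transformers import GPT2Config

# Assumption: values copied from the experiment Config in this README.
config = GPT2Config(
    vocab_size=1,                 # placeholder; no text vocabulary is used
    n_embd=128,                   # embed_dim
    n_layer=3,                    # n_layer
    n_head=1,                     # n_head
    activation_function="relu",   # activation_function
    resid_pdrop=0.1,              # dropout on residual connections
    attn_pdrop=0.1,               # dropout on attention weights
)

if __name__ == "__main__":
    print(config.n_embd, config.n_layer, config.n_head)
```

A `GPT2Model` built from such a config accepts pre-computed embeddings through its `inputs_embeds` argument, which is how state, action, and return embeddings could be passed in instead of token ids.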

Experiments

Group1: Decision Transformer — Hopper-v3-Medium-Dataset

Config parameters

import datetime  # needed for the test_cycles timestamp below


class Config:
    env = "hopper"
    dataset = "medium"
    mode = "normal" # "delayed" : all rewards moved to end of trajectory
    device = 'cuda'
    log_dir = 'TB_log/'
    record_algo = 'DT_Hopper_v1'
    test_cycles = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')

    # Model
    model_type = "DT"
    activation_function = 'relu'

    # Scalar
    max_length = 20 # max_len # K
    pct_traj = 1.
    batch_size = 64
    embed_dim = 128
    n_layer = 3
    n_head = 1
    dropout = 0.1
    lr = 1e-4
    wd = 1e-4
    warmup_steps = 1000
    num_eval_episodes = 100
    max_iters = 50
    num_steps_per_iter = 1000

    # Bool
    log_to_tb = True
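
A few of these settings are easier to read with a small worked example. The sketch below is illustrative only and uses hypothetical helper names: it assumes a linear learning-rate warmup over `warmup_steps` (the schedule used in the official implementation) and assumes that the "delayed" mode sums all rewards into the last step of each trajectory, as the comment on `mode` suggests.

```python
from typing import List, Sequence


def warmup_lr(step: int, base_lr: float = 1e-4, warmup_steps: int = 1000) -> float:
    """Linear warmup: the learning rate ramps up to base_lr, then stays constant."""
    return base_lr * min((step + 1) / warmup_steps, 1.0)


def delay_rewards(rewards: Sequence[float]) -> List[float]:
    """'delayed' mode: move the whole trajectory reward to the final step."""
    if not rewards:
        return []
    delayed = [0.0] * len(rewards)
    delayed[-1] = float(sum(rewards))
    return delayed


def total_gradient_steps(max_iters: int = 50, num_steps_per_iter: int = 1000) -> int:
    """Total number of optimizer updates over the whole run."""
    return max_iters * num_steps_per_iter


if __name__ == "__main__":
    print(warmup_lr(0), warmup_lr(999), warmup_lr(5000))
    print(delay_rewards([1.0, 2.0, 3.0]))
    print(total_gradient_steps())
```

Under these assumptions the run performs 50 × 1000 = 50,000 gradient steps, with the first 1,000 spent warming up the learning rate, and each training sequence spans K = 20 timesteps (60 tokens).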

Results

image.png

Owner
Irving