JAX code for the paper "Control-Oriented Model-Based Reinforcement Learning with Implicit Differentiation"

Overview

Optimal Model Design for Reinforcement Learning

This repository contains JAX code for the paper

Control-Oriented Model-Based Reinforcement Learning with Implicit Differentiation

by Evgenii Nikishin, Romina Abachi, Rishabh Agarwal, and Pierre-Luc Bacon.

Summary

Model based reinforcement learning typically trains the dynamics and reward functions by minimizing the error of predictions. The error is only a proxy to maximizing the sum of rewards, the ultimate goal of the agent, leading to the objective mismatch. We propose an end-to-end algorithm called Optimal Model Design (OMD) that optimizes the returns directly for model learning. OMD leverages the implicit function theorem to optimize the model parameters and forms the following computational graph:

Installation

We assume that you use Python 3. To install the necessary dependencies, run the following commands:

1. virtualenv ~/env_omd
2. source ~/env_omd/bin/activate
3. pip install -r requirements.txt

To use JAX with GPU, follow the official instructions. To install MuJoCo, check the instructions.

Run

For historical reasons, the code is divided into 3 parts.

Tabular

All results for the tabular experiments could be reproduced by running the tabular.ipynb notebook.

To open the notebook in Google Colab, use this link.

CartPole

To train the OMD agent on CartPole, use the following commands:

cd cartpole
python train.py --agent_type omd

We also provide the implementation of the corresponding MLE and VEP baselines. To train the agents, change the --agent_type flag to mle or vep.

MuJoCo

To train the OMD agent on MuJoCo HalfCheetah-v2, use the following commands:

cd mujoco
python train.py --config.algo=omd

To train the MLE baseline, change the --config.algo flag to mle.

Acknowledgements

Owner
Evgenii Nikishin
Evgenii Nikishin
[CVPR 2022] Semi-Supervised Semantic Segmentation Using Unreliable Pseudo-Labels

Using Unreliable Pseudo Labels Official PyTorch implementation of Semi-Supervised Semantic Segmentation Using Unreliable Pseudo Labels, CVPR 2022. Ple

Haochen Wang 268 Dec 24, 2022
FaceOcc: A Diverse, High-quality Face Occlusion Dataset for Human Face Extraction

FaceExtraction FaceOcc: A Diverse, High-quality Face Occlusion Dataset for Human Face Extraction Occlusions often occur in face images in the wild, tr

16 Dec 14, 2022
CRLT: A Unified Contrastive Learning Toolkit for Unsupervised Text Representation Learning

CRLT: A Unified Contrastive Learning Toolkit for Unsupervised Text Representation Learning This repository contains the code and relevant instructions

XiaoMing 5 Aug 19, 2022
FactSeg: Foreground Activation Driven Small Object Semantic Segmentation in Large-Scale Remote Sensing Imagery (TGRS)

FactSeg: Foreground Activation Driven Small Object Semantic Segmentation in Large-Scale Remote Sensing Imagery by Ailong Ma, Junjue Wang*, Yanfei Zhon

Kingdrone 43 Jan 05, 2023
A more easy-to-use implementation of KPConv based on PyTorch.

A more easy-to-use implementation of KPConv This repo contains a more easy-to-use implementation of KPConv based on PyTorch. Introduction KPConv is a

Zheng Qin 36 Dec 29, 2022
Official implementation for paper Knowledge Bridging for Empathetic Dialogue Generation (AAAI 2021).

Knowledge Bridging for Empathetic Dialogue Generation This is the official implementation for paper Knowledge Bridging for Empathetic Dialogue Generat

Qintong Li 50 Dec 20, 2022
A library for preparing, training, and evaluating scalable deep learning hybrid recommender systems using PyTorch.

collie_recs Collie is a library for preparing, training, and evaluating implicit deep learning hybrid recommender systems, named after the Border Coll

ShopRunner 97 Jan 03, 2023
Low Complexity Channel estimation with Neural Network Solutions

Interpolation-ResNet Invited paper for WSA 2021, called 'Low Complexity Channel estimation with Neural Network Solutions'. Low complexity residual con

Dianxin 10 Dec 10, 2022
Joint learning of images and text via maximization of mutual information

mutual_info_img_txt Joint learning of images and text via maximization of mutual information. This repository incorporates the algorithms presented in

Ruizhi Liao 10 Dec 22, 2022
This repository contains a set of codes to run (i.e., train, perform inference with, evaluate) a diarization method called EEND-vector-clustering.

EEND-vector clustering The EEND-vector clustering (End-to-End-Neural-Diarization-vector clustering) is a speaker diarization framework that integrates

45 Dec 26, 2022
Backend code to use MCPI's python API to make infinite worlds with custom generation

inf-mcpi Backend code to use MCPI's python API to make infinite worlds with custom generation Does not save player-placed blocks! Generation is still

5 Oct 04, 2022
Chatbot in 200 lines of code using TensorLayer

Seq2Seq Chatbot This is a 200 lines implementation of Twitter/Cornell-Movie Chatbot, please read the following references before you read the code: Pr

TensorLayer Community 820 Dec 17, 2022
[CoRL 2021] A robotics benchmark for cross-embodiment imitation.

x-magical x-magical is a benchmark extension of MAGICAL specifically geared towards cross-embodiment imitation. The tasks still provide the Demo/Test

Kevin Zakka 36 Nov 26, 2022
Python package for visualizing the loss landscape of parameterized quantum algorithms.

orqviz A Python package for easily visualizing the loss landscape of Variational Quantum Algorithms by Zapata Computing Inc. orqviz provides a collect

Zapata Computing, Inc. 75 Dec 30, 2022
Hand Gesture Volume Control is AIML based project which uses image processing to control the volume of your Computer.

Hand Gesture Volume Control Modules There are basically three modules Handtracking Program Handtracking Module Volume Control Program Handtracking Pro

VITTAL 1 Jan 12, 2022
AI Face Mesh: This is a simple face mesh detection program based on Artificial intelligence.

AI Face Mesh: This is a simple face mesh detection program based on Artificial Intelligence which made with Python. It's able to detect 468 different

Md. Rakibul Islam 1 Jan 13, 2022
YourTTS: Towards Zero-Shot Multi-Speaker TTS and Zero-Shot Voice Conversion for everyone

YourTTS: Towards Zero-Shot Multi-Speaker TTS and Zero-Shot Voice Conversion for everyone In our recent paper we propose the YourTTS model. YourTTS bri

Edresson Casanova 390 Dec 29, 2022
The first dataset on shadow generation for the foreground object in real-world scenes.

Object-Shadow-Generation-Dataset-DESOBA Object Shadow Generation is to deal with the shadow inconsistency between the foreground object and the backgr

BCMI 105 Dec 30, 2022
A Game-Theoretic Perspective on Risk-Sensitive Reinforcement Learning

Officile code repository for "A Game-Theoretic Perspective on Risk-Sensitive Reinforcement Learning"

Mathieu Godbout 1 Nov 19, 2021
SFD implement with pytorch

S³FD: Single Shot Scale-invariant Face Detector A PyTorch Implementation of Single Shot Scale-invariant Face Detector Description Meanwhile train hand

Jun Li 251 Dec 22, 2022