Code for "Unsupervised State Representation Learning in Atari"

Overview

Unsupervised State Representation Learning in Atari

Ankesh Anand*, Evan Racah*, Sherjil Ozair*, Yoshua Bengio, Marc-Alexandre Cรดtรฉ, R Devon Hjelm

This repo provides code for the benchmark and techniques introduced in the paper Unsupervised State Representation Learning in Atari

Install

AtariARI Wrapper

You can do a minimal install to get just the AtariARI (Atari Annotated RAM Interface) wrapper by doing:

pip install 'gym[atari]'
pip install git+git://github.com/mila-iqia/atari-representation-learning.git

This just requires gym[atari] and it gives you the ability to play around with the AtariARI wrapper. If you want to use the code for training representation learning methods and probing them, you will need a full installation:

Full installation (AtariARI Wrapper + Training & Probing Code)

# PyTorch and scikit learn
conda install pytorch torchvision -c pytorch
conda install scikit-learn

# Baselines for Atari preprocessing
# Tensorflow is a dependency, but you don't need to install the GPU version
conda install tensorflow
pip install git+git://github.com/openai/baselines

# pytorch-a2c-ppo-acktr for RL utils
pip install git+git://github.com/ankeshanand/pytorch-a2c-ppo-acktr-gail

# Clone and install our package
pip install -r requirements.txt
pip install git+git://github.com/mila-iqia/atari-representation-learning.git

Usage

Atari Annotated RAM Interface (AtariARI):

AtariARI exposes the ground truth labels for different state variables for each observation. We have made AtariARI available as a Gym wrapper, to use it simply wrap an Atari gym env with AtariARIWrapper.

import gym
from atariari.benchmark.wrapper import AtariARIWrapper
env = AtariARIWrapper(gym.make('MsPacmanNoFrameskip-v4'))
obs = env.reset()
obs, reward, done, info = env.step(1)

Now, info is a dictionary of the form:

{'ale.lives': 3,
 'labels': {'enemy_sue_x': 88,
  'enemy_inky_x': 88,
  'enemy_pinky_x': 88,
  'enemy_blinky_x': 88,
  'enemy_sue_y': 80,
  'enemy_inky_y': 80,
  'enemy_pinky_y': 80,
  'enemy_blinky_y': 50,
  'player_x': 88,
  'player_y': 98,
  'fruit_x': 0,
  'fruit_y': 0,
  'ghosts_count': 3,
  'player_direction': 3,
  'dots_eaten_count': 0,
  'player_score': 0,
  'num_lives': 2}}

Note: In our experiments, we use additional preprocessing for Atari environments mainly following Minh et. al, 2014. See atariari/benchmark/envs.py for more info!

If you want the raw RAM annotations (which parts of ram correspond to each state variable), check out atariari/benchmark/ram_annotations.py

Probing


โš ๏ธ Important โš ๏ธ : The RAM labels are meant for full-sized Atari observations (210 * 160). Probing results won't be accurate if you downsample the observations.

We provide an interface for the included probing tasks.

First, get episodes for train, val and, test:

from atariari.benchmark.episodes import get_episodes

tr_episodes, val_episodes,\
tr_labels, val_labels,\
test_episodes, test_labels = get_episodes(env_name="PitfallNoFrameskip-v4", 
                                     steps=50000, 
                                     collect_mode="random_agent")

Then probe them using ProbeTrainer and your encoder (my_encoder):

from atariari.benchmark.probe import ProbeTrainer

probe_trainer = ProbeTrainer(my_encoder, representation_len=my_encoder.feature_size)
probe_trainer.train(tr_episodes, val_episodes,
                     tr_labels, val_labels,)
final_accuracies, final_f1_scores = probe_trainer.test(test_episodes, test_labels)

To see how we use ProbeTrainer, check out scripts/run_probe.py

Here is an example of my_encoder:

# get your encoder
import torch.nn as nn
import torch
class MyEncoder(nn.Module):
    def __init__(self, input_channels, feature_size):
        super().__init__()
        self.feature_size = feature_size
        self.input_channels = input_channels
        self.final_conv_size = 64 * 9 * 6
        self.cnn = nn.Sequential(
            nn.Conv2d(input_channels, 32, 8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2),
            nn.ReLU(),
            nn.Conv2d(128, 64, 3, stride=1),
            nn.ReLU()
        )
        self.fc = nn.Linear(self.final_conv_size, self.feature_size)

    def forward(self, inputs):
        x = self.cnn(inputs)
        x = x.view(x.size(0), -1)
        return self.fc(x)
        

my_encoder = MyEncoder(input_channels=1,feature_size=256)
# load in weights
my_encoder.load_state_dict(torch.load(open("path/to/my/weights.pt", "rb")))

Spatio-Temporal DeepInfoMax:

src/ contains implementations of several representation learning methods, along with ST-DIM. Here's a sample usage:

python -m scripts.run_probe --method infonce-stdim --env-name {env_name}

where env_name is of the form {game}NoFrameskip-v4, such as PongNoFrameskip-v4

Citation

@article{anand2019unsupervised,
  title={Unsupervised State Representation Learning in Atari},
  author={Anand, Ankesh and Racah, Evan and Ozair, Sherjil and Bengio, Yoshua and C{\^o}t{\'e}, Marc-Alexandre and Hjelm, R Devon},
  journal={arXiv preprint arXiv:1906.08226},
  year={2019}
}
Owner
Mila
Quebec Artificial Intelligence Institute
Mila
Code for NAACL 2021 full paper "Efficient Attentions for Long Document Summarization"

LongDocSum Code for NAACL 2021 paper "Efficient Attentions for Long Document Summarization" This repository contains data and models needed to reprodu

56 Jan 02, 2023
ALBERT-pytorch-implementation - ALBERT pytorch implementation

ALBERT-pytorch-implementation developing... ๋ชจ๋ธ์˜ ๊ฐœ๋…์ดํ•ด๋ฅผ ๋•๊ธฐ ์œ„ํ•œ ๊ตฌํ˜„๋ฌผ๋กœ ํ˜„์žฌ ๋ณ€์ˆ˜๋ช…์„ ์ƒ์„ธํžˆ ์ ์—ˆ๊ณ 

BG Kim 3 Oct 06, 2022
A graph neural network (GNN) model to predict protein-protein interactions (PPI) with no sample features

A graph neural network (GNN) model to predict protein-protein interactions (PPI) with no sample features

2 Jul 25, 2022
Self-supervised learning optimally robust representations for domain generalization.

OptDom: Learning Optimal Representations for Domain Generalization This repository contains the official implementation for Optimal Representations fo

Yangjun Ruan 18 Aug 25, 2022
YOLOv5 Series Multi-backbone, Pruning and quantization Compression Tool Box.

YOLOv5-Compression Update News Requirements ็Žฏๅขƒๅฎ‰่ฃ… pip install -r requirements.txt Evaluation metric Visdrone Model mAP ZhangYuan 719 Jan 02, 2023

[NeurIPS2021] Code Release of K-Net: Towards Unified Image Segmentation

K-Net: Towards Unified Image Segmentation Introduction This is an official release of the paper K-Net:Towards Unified Image Segmentation. K-Net will a

Wenwei Zhang 423 Jan 02, 2023
True per-item rarity for Loot

True-Rarity True per-item rarity for Loot (For Adventurers) and More Loot A.K.A mLoot each out/true_rarity_{item_type}.json file contains probabilitie

Dan R. 3 Jul 26, 2022
Implementing a simplified copy of Shazam application from scratch using MinHashing and LSH.

Building Shazam from scratch In this repository we tried to implement a simplified copy of the Shazam application able to tell you the name of a song

Arturo Ghinassi 0 Nov 17, 2022
DeepStochlog Package For Python

DeepStochLog Installation Installing SWI Prolog DeepStochLog requires SWI Prolog to run. Run the following commands to install: sudo apt-add-repositor

KU Leuven Machine Learning Research Group 17 Dec 23, 2022
LibMTL: A PyTorch Library for Multi-Task Learning

LibMTL LibMTL is an open-source library built on PyTorch for Multi-Task Learning (MTL). See the latest documentation for detailed introductions and AP

765 Jan 06, 2023
Open Source Differentiable Computer Vision Library for PyTorch

Kornia is a differentiable computer vision library for PyTorch. It consists of a set of routines and differentiable modules to solve generic computer

kornia 7.6k Jan 04, 2023
Equivariant layers for RC-complement symmetry in DNA sequence data

Equi-RC Equivariant layers for RC-complement symmetry in DNA sequence data This is a repository that implements the layers as described in "Reverse-Co

7 May 19, 2022
Deep Sketch-guided Cartoon Video Inbetweening

Cartoon Video Inbetweening Paper | DOI | Video The source code of Deep Sketch-guided Cartoon Video Inbetweening by Xiaoyu Li, Bo Zhang, Jing Liao, Ped

Xiaoyu Li 37 Dec 22, 2022
Demonstration of the Model Training as a CI/CD System in Vertex AI

Model Training as a CI/CD System This project demonstrates the machine model training as a CI/CD system in GCP platform. You will see more detailed wo

Chansung Park 19 Dec 28, 2022
[NeurIPS-2021] Slow Learning and Fast Inference: Efficient Graph Similarity Computation via Knowledge Distillation

Efficient Graph Similarity Computation - (EGSC) This repo contains the source code and dataset for our paper: Slow Learning and Fast Inference: Effici

23 Nov 11, 2022
PyTorch implementation of "MLP-Mixer: An all-MLP Architecture for Vision" Tolstikhin et al. (2021)

mlp-mixer-pytorch PyTorch implementation of "MLP-Mixer: An all-MLP Architecture for Vision" Tolstikhin et al. (2021) Usage import torch from mlp_mixer

isaac 27 Jul 09, 2022
Change is Everywhere: Single-Temporal Supervised Object Change Detection in Remote Sensing Imagery (ICCV 2021)

Change is Everywhere Single-Temporal Supervised Object Change Detection in Remote Sensing Imagery by Zhuo Zheng, Ailong Ma, Liangpei Zhang and Yanfei

Zhuo Zheng 125 Dec 13, 2022
High level network definitions with pre-trained weights in TensorFlow

TensorNets High level network definitions with pre-trained weights in TensorFlow (tested with 2.1.0 = TF = 1.4.0). Guiding principles Applicability.

Taehoon Lee 1k Dec 13, 2022
Pytorch library for seismic data augmentation

Pytorch library for seismic data augmentation

Artemii Novoselov 27 Nov 22, 2022