End-To-End Memory Network using Tensorflow

Overview

MemN2N

Implementation of End-To-End Memory Networks with sklearn-like interface using Tensorflow. Tasks are from the bAbl dataset.

MemN2N picture

Get Started

git clone [email protected]:domluna/memn2n.git

mkdir ./memn2n/data/
cd ./memn2n/data/
wget http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2.tar.gz
tar xzvf ./tasks_1-20_v1-2.tar.gz

cd ../
python single.py

Examples

Running a single bAbI task

Running a joint model on all bAbI tasks

These files are also a good example of usage.

Requirements

  • tensorflow 1.0
  • scikit-learn 0.17.1
  • six 1.10.0

Single Task Results

For a task to pass it has to meet 95%+ testing accuracy. Measured on single tasks on the 1k data.

Pass: 1,4,12,15,20

Several other tasks have 80%+ testing accuracy.

Stochastic gradient descent optimizer was used with an annealed learning rate schedule as specified in Section 4.2 of End-To-End Memory Networks

The following params were used:

  • epochs: 100
  • hops: 3
  • embedding_size: 20
Task Training Accuracy Validation Accuracy Testing Accuracy
1 1.0 1.0 1.0
2 1.0 0.86 0.83
3 1.0 0.64 0.54
4 1.0 0.99 0.98
5 1.0 0.94 0.87
6 1.0 0.97 0.92
7 1.0 0.89 0.84
8 1.0 0.93 0.86
9 1.0 0.86 0.90
10 1.0 0.80 0.78
11 1.0 0.92 0.84
12 1.0 1.0 1.0
13 0.99 0.94 0.90
14 1.0 0.97 0.93
15 1.0 1.0 1.0
16 0.81 0.47 0.44
17 0.76 0.65 0.52
18 0.97 0.96 0.88
19 0.40 0.17 0.13
20 1.0 1.0 1.0

Joint Training Results

Pass: 1,6,9,10,12,13,15,20

Again stochastic gradient descent optimizer was used with an annealed learning rate schedule as specified in Section 4.2 of End-To-End Memory Networks

The following params were used:

  • epochs: 60
  • hops: 3
  • embedding_size: 40
Task Training Accuracy Validation Accuracy Testing Accuracy
1 1.0 0.99 0.999
2 1.0 0.84 0.849
3 0.99 0.72 0.715
4 0.96 0.86 0.851
5 1.0 0.92 0.865
6 1.0 0.97 0.964
7 0.96 0.87 0.851
8 0.99 0.89 0.898
9 0.99 0.96 0.96
10 1.0 0.96 0.928
11 1.0 0.98 0.93
12 1.0 0.98 0.982
13 0.99 0.98 0.976
14 1.0 0.81 0.877
15 1.0 1.0 0.983
16 0.64 0.45 0.44
17 0.77 0.64 0.547
18 0.85 0.71 0.586
19 0.24 0.07 0.104
20 1.0 1.0 0.996

Notes

Single task results are from 10 repeated trails of the single task model accross all 20 tasks with different random initializations. The performance of the model with the lowest validation accuracy for each task is shown in the table above.

Joint training results are from 10 repeated trails of the joint model accross all tasks. The performance of the single model whose validation accuracy passed the most tasks (>= 0.95) is shown in the table above (joint_scores_run2.csv). The scores from all 10 runs are located in the results/ directory.

Owner
Dominique Luna
magnificent stallion
Dominique Luna
This repository contains the code for designing risk bounded motion plans for car-like robot using Carla Simulator.

Nonlinear Risk Bounded Robot Motion Planning This code simulates the bicycle dynamics of car by steering it on the road by avoiding another static car

8 Sep 03, 2022
Official Pytorch Implementation of Unsupervised Image Denoising with Frequency Domain Knowledge

Unsupervised Image Denoising with Frequency Domain Knowledge (BMVC 2021 Oral) : Official Project Page This repository provides the official PyTorch im

Donggon Jang 12 Sep 26, 2022
Official Tensorflow implementation of U-GAT-IT: Unsupervised Generative Attentional Networks with Adaptive Layer-Instance Normalization for Image-to-Image Translation (ICLR 2020)

U-GAT-IT — Official TensorFlow Implementation (ICLR 2020) : Unsupervised Generative Attentional Networks with Adaptive Layer-Instance Normalization fo

Junho Kim 6.2k Jan 04, 2023
OpenMMLab Semantic Segmentation Toolbox and Benchmark.

Documentation: https://mmsegmentation.readthedocs.io/ English | 简体中文 Introduction MMSegmentation is an open source semantic segmentation toolbox based

OpenMMLab 5k Dec 31, 2022
Churn-Prediction-Project - In this project, a churn prediction model is developed for a private bank as a term project for Data Mining class.

Churn-Prediction-Project In this project, a churn prediction model is developed for a private bank as a term project for Data Mining class. Project in

1 Jan 03, 2022
Distributed Deep learning with Keras & Spark

Elephas: Distributed Deep Learning with Keras & Spark Elephas is an extension of Keras, which allows you to run distributed deep learning models at sc

Max Pumperla 1.6k Jan 05, 2023
Image restoration with neural networks but without learning.

Warning! The optimization may not converge on some GPUs. We've personally experienced issues on Tesla V100 and P40 GPUs. When running the code, make s

Dmitry Ulyanov 7.4k Jan 01, 2023
A TensorFlow implementation of the Mnemonic Descent Method.

MDM A Tensorflow implementation of the Mnemonic Descent Method. Mnemonic Descent Method: A recurrent process applied for end-to-end face alignment G.

123 Oct 07, 2022
A PyTorch implementation of the Transformer model in "Attention is All You Need".

Attention is all you need: A Pytorch Implementation This is a PyTorch implementation of the Transformer model in "Attention is All You Need" (Ashish V

Yu-Hsiang Huang 7.1k Jan 04, 2023
Look Who’s Talking: Active Speaker Detection in the Wild

Look Who's Talking: Active Speaker Detection in the Wild Dependencies pip install -r requirements.txt In addition to the Python dependencies, ffmpeg

Clova AI Research 60 Dec 08, 2022
Code for "MetaMorph: Learning Universal Controllers with Transformers", Gupta et al, ICLR 2022

MetaMorph: Learning Universal Controllers with Transformers This is the code for the paper MetaMorph: Learning Universal Controllers with Transformers

Agrim Gupta 50 Jan 03, 2023
🔮 Execution time predictions for deep neural network training iterations across different GPUs.

Habitat: A Runtime-Based Computational Performance Predictor for Deep Neural Network Training Habitat is a tool that predicts a deep neural network's

Geoffrey Yu 44 Dec 27, 2022
Simple image captioning model - CLIP prefix captioning.

Simple image captioning model - CLIP prefix captioning.

688 Jan 04, 2023
Malmo Collaborative AI Challenge - Team Pig Catcher

The Malmo Collaborative AI Challenge - Team Pig Catcher Approach The challenge involves 2 agents who can either cooperate or defect. The optimal polic

Kai Arulkumaran 66 Jun 29, 2022
Collections for the lasted paper about multi-view clustering methods (papers, codes)

Multi-View Clustering Papers Collections for the lasted paper about multi-view clustering methods (papers, codes). There also exists some repositories

Andrew Guan 10 Sep 20, 2022
Software Platform for solving and manipulating multiparametric programs in Python

PPOPT Python Parametric OPtimization Toolbox (PPOPT) is a software platform for solving and manipulating multiparametric programs in Python. This pack

10 Sep 13, 2022
Massively parallel Monte Carlo diffusion MR simulator written in Python.

Disimpy Disimpy is a Python package for generating simulated diffusion-weighted MR signals that can be useful in the development and validation of dat

Leevi 16 Nov 11, 2022
U-Net Implementation: Convolutional Networks for Biomedical Image Segmentation" using the Carvana Image Masking Dataset in PyTorch

U-Net Implementation By Christopher Ley This is my interpretation and implementation of the famous paper "U-Net: Convolutional Networks for Biomedical

Christopher Ley 1 Jan 06, 2022
The world's simplest facial recognition api for Python and the command line

Face Recognition You can also read a translated version of this file in Chinese 简体中文版 or in Korean 한국어 or in Japanese 日本語. Recognize and manipulate fa

Adam Geitgey 46.9k Jan 03, 2023
Who calls the shots? Rethinking Few-Shot Learning for Audio (WASPAA 2021)

rethink-audio-fsl This repo contains the source code for the paper "Who calls the shots? Rethinking Few-Shot Learning for Audio." (WASPAA 2021) Table

Yu Wang 34 Dec 24, 2022