Implementation of Vaswani, Ashish, et al. "Attention is all you need."

Overview

Attention Is All You Need Paper Implementation

This is my from-scratch implementation of the original transformer architecture from the following paper: Vaswani, Ashish, et al. "Attention is all you need." Advances in neural information processing systems. 2017.

About

"We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely." - Abstract

Transformers were a groundbreaking advance in neural network architectures that changed what is possible in NLP and beyond. Notable applications include BERT in Google Search and GPT in GitHub Copilot. Those architectures build on the original transformer architecture described in this seminal paper. The goal of this repository is to provide an implementation that is easy to follow and understand while reading the paper. Setup is easy and everything is runnable on CPU for learning purposes.

✔️ Highly customizable configuration and training loop
✔️ Runnable on CPU and GPU
✔️ W&B integration for detailed logging of every metric
✔️ Pretrained models and their training details
✔️ Gradient Accumulation
✔️ Label smoothing
✔️ BPE and WordLevel Tokenizers
✔️ Dynamic Batching
✔️ Batch Dataset Processing
✔️ BLEU score calculation during training
✔️ Documented dimensions for every step of the architecture
✔️ Translation progress shown for an example sentence after every epoch
✔️ Tutorial notebook (Coming soon...)

Setup

Environment

Using Miniconda/Anaconda:

  1. cd path_to_repo
  2. conda env create
  3. conda activate attention-is-all-you-need-paper

Note: Depending on your GPU you might need to switch cudatoolkit to version 10.2

Pretrained Models

To download the pretrained model and tokenizer run:

python scripts/download_pretrained.py

Note: If prompted about the wandb setting, select option 3.

Usage

Training

Before training, either choose one of the available configurations or create your own; all of them live in a single file, src/config.py. The customizable parameters, grouped by category, are:

  • Run 🚅 :
    • RUN_NAME - Name of a training run
    • RUN_DESCRIPTION - Description of a training run
    • RUNS_FOLDER_PTH - Saving destination of a training run
  • Data 🔡 :
    • DATASET_SIZE - Number of examples to include from the WMT14 en-de dataset (max 4,500,000)
    • TEST_PROPORTION - Test set proportion
    • MAX_SEQ_LEN - Maximum allowed sequence length
    • VOCAB_SIZE - Size of the vocabulary (a good choice depends on the tokenizer)
    • TOKENIZER_TYPE - 'wordlevel' or 'bpe'
  • Training 🏋️‍♂️ :
    • BATCH_SIZE - Batch size
    • GRAD_ACCUMULATION_STEPS - Over how many batches to accumulate gradients before optimizing the parameters
    • WORKER_COUNT - Number of workers used in dataloaders
    • EPOCHS - Number of epochs
  • Optimizer 📉 :
    • BETAS - Adam beta parameters
    • EPS - Adam eps parameter
  • Scheduler ⏲️ :
    • N_WARMUP_STEPS - How many warmup steps to use in the scheduler
  • Model 🤖 :
    • D_MODEL - Model dimension
    • N_BLOCKS - Number of encoder and decoder blocks
    • N_HEADS - Number of heads in the Multi-Head attention mechanism
    • D_FF - Dimension of the Position Wise Feed Forward network
    • DROPOUT_PROBA - Dropout probability
  • Other 🧰 :
    • DEVICE - 'gpu' or 'cpu'
    • MODEL_SAVE_EPOCH_CNT - After how many epochs to save a model checkpoint
    • LABEL_SMOOTHING - Whether to apply label smoothing
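
To make the scheduler and model parameters more concrete: the paper (Section 5.3) raises the learning rate linearly over the warmup steps and then decays it with the inverse square root of the step number. The sketch below only illustrates that formula. The function name `transformer_lr` is hypothetical, the defaults 512 and 4000 are the paper's base values, and whether this repository's scheduler matches it exactly is an assumption.

```python
def transformer_lr(step: int, d_model: int = 512, n_warmup_steps: int = 4000) -> float:
    """Learning rate schedule from the paper:
    d_model^-0.5 * min(step^-0.5, step * n_warmup_steps^-1.5).
    """
    step = max(step, 1)  # guard against step 0, where step ** -0.5 is undefined
    return d_model ** -0.5 * min(step ** -0.5, step * n_warmup_steps ** -1.5)


# Hypothetical usage: inspect the rate during warmup, at the peak, and after it.
for s in (1, 2000, 4000, 8000):
    print(s, transformer_lr(s))
```

The rate peaks exactly at `step == n_warmup_steps`, so a larger N_WARMUP_STEPS gives a lower and later peak, and a larger D_MODEL scales the whole curve down.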

Once you have decided on a configuration, set config_name in train.py and run:

$ cd src
$ python train.py
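
GRAD_ACCUMULATION_STEPS trades memory for time: gradients from several small batches are combined before a single optimizer step, which approximates training with a batch of size BATCH_SIZE * GRAD_ACCUMULATION_STEPS. The toy example below is a sketch in plain Python with a one-parameter linear model, not the repository's training loop. The helper names `grad_mse` and `accumulated_grad` are hypothetical. It shows that averaging the gradients of equally sized micro-batches reproduces the gradient of the combined batch.

```python
def grad_mse(w, batch):
    """Gradient of the mean squared error for the 1-D model y = w * x."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulated_grad(w, data, batch_size, grad_accumulation_steps):
    """Combine micro-batch gradients into the gradient of one large batch."""
    total = 0.0
    for i in range(grad_accumulation_steps):
        micro = data[i * batch_size:(i + 1) * batch_size]
        # Dividing by the number of steps turns the sum into a mean,
        # so the result matches the gradient over the combined batch.
        total += grad_mse(w, micro) / grad_accumulation_steps
    return total


data = [(1.0, 2.0), (2.0, 4.1), (3.0, 5.9), (4.0, 8.2)]
w = 0.5
full = grad_mse(w, data)  # one batch of 4 examples
accum = accumulated_grad(w, data, batch_size=2, grad_accumulation_steps=2)
print(full, accum)  # the two values agree up to floating-point error
```

In a PyTorch training loop the same idea is usually expressed by calling `backward()` on the loss divided by the number of accumulation steps and calling the optimizer's `step()` only once per group of batches.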

Inference

For inference I created a simple Streamlit app that runs in your browser. Make sure to train a model or download the pretrained one beforehand. The app looks in the model directory for model and tokenizer checkpoints.

$ streamlit run app/inference_app.py

Data

The same WMT 2014 data as in the paper is used for the English-to-German translation task. The dataset contains about 4,500,000 sentence pairs, but you can manually specify a smaller dataset size to see results faster. When training starts, the dataset is automatically downloaded, preprocessed and tokenized, and the dataloaders are created. A custom batch sampler provides dynamic batching: sentences of similar lengths are batched and padded together, which speeds up training. HuggingFace 🤗 datasets and tokenizers make these steps fast.
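
The benefit of batching sentences of similar lengths can be seen in a small sketch. This is an illustration of the general idea, not the repository's sampler. The function names `length_bucketed_batches` and `padding_tokens` and the example lengths are made up. Examples are sorted by length, cut into batches, and the batch order (not the batch contents) is optionally shuffled.

```python
import random


def length_bucketed_batches(lengths, batch_size, shuffle=True, seed=0):
    """Return batches of example indices whose sequences have similar lengths."""
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    batches = [order[i:i + batch_size] for i in range(0, len(order), batch_size)]
    if shuffle:
        # Shuffle the order of batches, keeping similar lengths together.
        random.Random(seed).shuffle(batches)
    return batches


def padding_tokens(lengths, batches):
    """Count pad tokens when every batch is padded to its longest sequence."""
    return sum(
        max(lengths[i] for i in b) * len(b) - sum(lengths[i] for i in b)
        for b in batches
    )


lengths = [3, 50, 4, 48, 5, 52, 6, 49]  # token counts of 8 hypothetical sentences
bucketed = length_bucketed_batches(lengths, batch_size=4, shuffle=False)
naive = [[0, 1, 2, 3], [4, 5, 6, 7]]  # batches in dataset order
print(padding_tokens(lengths, bucketed), padding_tokens(lengths, naive))
```

Fewer pad tokens means less wasted computation per batch, which is where the speedup comes from.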

Architecture

The original transformer architecture presented in the paper consists of an encoder and a decoder, a design chosen to match the sequence-to-sequence nature of machine translation. There are also encoder-only (e.g. BERT) and decoder-only (e.g. GPT) transformer architectures; those are not covered here. One of the main features of transformers in general is parallelized sequence processing, which RNNs lack. The main ingredient is the attention mechanism, which creates modified word representations (attention representations) that take into account a word's meaning in relation to the other words in the sequence (e.g. the word "bank" can mean a financial institution or the land along the edge of a river, as in "river bank"). Depending on the context in which a word appears, it is represented differently. This goes beyond the limits of traditional, context-independent word embeddings.
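
The core operation behind these context-dependent representations is the paper's scaled dot-product attention, softmax(Q K^T / sqrt(d_k)) V. The sketch below implements it with plain Python lists so that it runs without any dependencies. It is meant as a reading aid only: the repository itself works on batched tensors, and the function names and the toy vectors here are illustrative assumptions.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_dot_product_attention(queries, keys, values):
    """Compute softmax(Q K^T / sqrt(d_k)) V for lists of vectors."""
    d_k = len(keys[0])
    outputs, weights = [], []
    for q in queries:
        # Similarity of this query to every key, scaled by sqrt(d_k).
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in keys]
        w = softmax(scores)
        weights.append(w)
        # Each output is a weighted average of the value vectors.
        outputs.append(
            [sum(wj * v[c] for wj, v in zip(w, values)) for c in range(len(values[0]))]
        )
    return outputs, weights


# Self-attention over three toy 2-D token vectors (queries = keys = values).
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
out, w = scaled_dot_product_attention(x, x, x)
print(w)    # one row of attention weights per token, each row sums to 1
print(out)  # each token's new representation mixes in the other tokens
```

Because every output row depends on all tokens in the sequence, the same input vector ends up with a different representation in a different sentence, and all rows can be computed in parallel.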

For a detailed walkthrough of the architecture, see notebooks/tutorial.ipynb.

Weights and Biases Logs

Weights and Biases is a powerful tool for MLOps. I integrated it with this project to automatically provide useful logs and visualizations during training. You can see how training went for the pretrained models at this project link. All logs and visualizations are synced to the cloud in real time.

When you start training you will be asked:

wandb: (1) Create W&B account
wandb: (2) Use an existing W&B account
wandb: (3) Don't visualize my results
wandb: Enter your choice: 

To create the visualizations and sync them to the cloud you need a W&B account. Creating and using an account takes about a minute and is free. If you don't want to visualize results, select option 3.

Citation

Please use this bibtex if you want to cite this repository:

@misc{Koch2021attentionisallyouneed,
  author = {Koch, Brando},
  title = {attention-is-all-you-need},
  year = {2021},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/bkoch4142/MISSING}},
}

License

This repository is licensed under the MIT License.

Owner
Brando Koch
Machine Learning Engineer with experience in ML, DL, NLP & CV specializing in ConversationalAI & NLP.