Implementation of Vaswani, Ashish, et al. "Attention is all you need."

Overview

Attention Is All You Need Paper Implementation

This is my from-scratch implementation of the original transformer architecture from the following paper: Vaswani, Ashish, et al. "Attention is all you need." Advances in neural information processing systems. 2017.

About

"We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely." - Abstract

Transformers were a groundbreaking advance in neural network architectures that revolutionized what we can do in NLP and beyond. Applications include BERT in Google Search and GPT in GitHub Copilot. Those architectures build on the original transformer architecture described in this seminal paper. The goal of this repository is to provide an implementation that is easy to follow and understand while reading the paper. Setup is easy, and everything is runnable on CPU for learning purposes.

✔️ Highly customizable configuration and training loop
✔️ Runnable on CPU and GPU
✔️ W&B integration for detailed logging of every metric
✔️ Pretrained models and their training details
✔️ Gradient Accumulation
✔️ Label smoothing
✔️ BPE and WordLevel Tokenizers
✔️ Dynamic Batching
✔️ Batch Dataset Processing
✔️ BLEU score calculation during training
✔️ Documented dimensions for every step of the architecture
✔️ Translation progress shown for an example sentence after every epoch
✔️ Tutorial notebook (Coming soon...)

Setup

Environment

Using Miniconda/Anaconda:

  1. cd path_to_repo
  2. conda env create
  3. conda activate attention-is-all-you-need-paper

Note: Depending on your GPU, you might need to switch cudatoolkit to version 10.2.

Pretrained Models

To download the pretrained model and tokenizer run:

python scripts/download_pretrained.py

Note: If prompted about the wandb setting, select option 3.

Usage

Training

Before starting training, you can either choose one of the available configurations or create your own. All configurations live in a single file, src/config.py. The customizable parameters, sorted by category, are:

  • Run 🚅 :
    • RUN_NAME - Name of a training run
    • RUN_DESCRIPTION - Description of a training run
    • RUNS_FOLDER_PTH - Saving destination of a training run
  • Data 🔡 :
    • DATASET_SIZE - Number of examples to include from the WMT14 en-de dataset (max 4,500,000)
    • TEST_PROPORTION - Test set proportion
    • MAX_SEQ_LEN - Maximum allowed sequence length
    • VOCAB_SIZE - Size of the vocabulary (a good choice depends on the tokenizer)
    • TOKENIZER_TYPE - 'wordlevel' or 'bpe'
  • Training 🏋️‍♂️ :
    • BATCH_SIZE - Batch size
    • GRAD_ACCUMULATION_STEPS - Over how many batches to accumulate gradients before optimizing the parameters
    • WORKER_COUNT - Number of workers used in dataloaders
    • EPOCHS - Number of epochs
  • Optimizer 📉 :
    • BETAS - Adam beta parameters
    • EPS - Adam eps parameter
  • Scheduler ⏲️ :
    • N_WARMUP_STEPS - How many warmup steps to use in the scheduler
  • Model 🤖 :
    • D_MODEL - Model dimension
    • N_BLOCKS - Number of encoder and decoder blocks
    • N_HEADS - Number of heads in the Multi-Head attention mechanism
    • D_FF - Dimension of the Position Wise Feed Forward network
    • DROPOUT_PROBA - Dropout probability
  • Other 🧰 :
    • DEVICE - 'gpu' or 'cpu'
    • MODEL_SAVE_EPOCH_CNT - After how many epochs to save a model checkpoint
    • LABEL_SMOOTHING - Whether to apply label smoothing
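
Two of the parameters above, N_WARMUP_STEPS and D_MODEL, interact through the learning-rate schedule. The paper (Section 5.3) raises the learning rate linearly during warmup and then decays it with the inverse square root of the step number. The following is a minimal sketch of that formula, assuming the scheduler in this repository follows the paper. The name `noam_lr` and its defaults (512 and 4000, the paper's base values) are illustrative and are not taken from src/config.py.

```python
def noam_lr(step: int, d_model: int = 512, n_warmup_steps: int = 4000) -> float:
    """Learning rate schedule from Section 5.3 of the paper.

    lr = d_model^-0.5 * min(step^-0.5, step * n_warmup_steps^-1.5)
    """
    step = max(step, 1)  # guard against step 0, where step^-0.5 is undefined
    return d_model ** -0.5 * min(step ** -0.5, step * n_warmup_steps ** -1.5)


# The rate grows until step == n_warmup_steps and shrinks afterwards.
for s in (1, 1000, 4000, 16000):
    print(f"step {s:>6}: lr = {noam_lr(s):.6f}")
```

Under this schedule a larger N_WARMUP_STEPS lowers and delays the peak learning rate, and a larger D_MODEL scales the whole curve down.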

Once you decide on the configuration, edit config_name in train.py and run:

$ cd src
$ python train.py
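
To illustrate what GRAD_ACCUMULATION_STEPS controls, here is a small sketch of gradient accumulation on a toy one-parameter model in plain Python. It is not the training loop from train.py. All names (`mse_grad`, `train_with_accumulation`) are hypothetical, and the sketch assumes micro-batches of equal size. The idea is that gradients from several small batches are averaged before a single parameter update, which emulates a larger batch without needing the memory for it.

```python
from typing import List, Tuple

Batch = List[Tuple[float, float]]  # (x, y) pairs


def mse_grad(w: float, batch: Batch) -> float:
    """Gradient of the mean squared error of y_hat = w * x with respect to w."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def train_with_accumulation(
    w: float, batches: List[Batch], lr: float, accumulation_steps: int
) -> float:
    """Plain SGD that updates w only once every `accumulation_steps` batches."""
    grad_sum = 0.0
    for i, batch in enumerate(batches, start=1):
        # Divide by accumulation_steps so the accumulated value is an average.
        grad_sum += mse_grad(w, batch) / accumulation_steps
        if i % accumulation_steps == 0:
            w -= lr * grad_sum  # one optimizer step for the whole group
            grad_sum = 0.0
    return w


data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]
# Two micro-batches of size 2, accumulated, behave like one batch of size 4.
w_accumulated = train_with_accumulation(0.0, [data[:2], data[2:]], lr=0.01, accumulation_steps=2)
w_full_batch = 0.0 - 0.01 * mse_grad(0.0, data)
print(w_accumulated, w_full_batch)
```

With equal-sized micro-batches the effective batch size is BATCH_SIZE multiplied by GRAD_ACCUMULATION_STEPS.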

Inference

For inference, I created a simple Streamlit app that runs in your browser. Make sure to train a model or download the pretrained one beforehand. The app looks in the model directory for model and tokenizer checkpoints.

$ streamlit run app/inference_app.py

Data

The same WMT 2014 data as in the paper is used for the English-to-German translation task. The dataset contains about 4,500,000 sentence pairs, but you can specify a smaller dataset size to see results faster. When training starts, the dataset is automatically downloaded, preprocessed, and tokenized, and the dataloaders are created. A custom batch sampler performs dynamic batching: sentences of similar length are batched and padded together, which speeds up training. HuggingFace 🤗 datasets and tokenizers make these steps fast.
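
The dynamic batching idea can be sketched in a few lines of plain Python. This is an illustration under simple assumptions, not the batch sampler used in this repository: `PAD_ID`, `length_bucketed_batches`, and `pad_batch` are hypothetical names, and a real sampler would typically also shuffle the resulting batches between epochs.

```python
from typing import List

PAD_ID = 0  # hypothetical padding token id


def length_bucketed_batches(sequences: List[List[int]], batch_size: int) -> List[List[int]]:
    """Return batches of indices so that each batch holds sequences of similar length."""
    order = sorted(range(len(sequences)), key=lambda i: len(sequences[i]))
    return [order[i:i + batch_size] for i in range(0, len(order), batch_size)]


def pad_batch(sequences: List[List[int]], indices: List[int], pad_id: int = PAD_ID) -> List[List[int]]:
    """Pad every sequence in one batch to the length of that batch's longest sequence."""
    max_len = max(len(sequences[i]) for i in indices)
    return [sequences[i] + [pad_id] * (max_len - len(sequences[i])) for i in indices]


token_ids = [[1, 2, 3, 4, 5], [6], [7, 8], [9, 10, 11, 12]]
for batch_indices in length_bucketed_batches(token_ids, batch_size=2):
    print(pad_batch(token_ids, batch_indices))
```

Because each batch is padded only to its own longest sequence, short sentences are not padded up to the length of the longest sentence in the dataset, so fewer padding tokens are processed.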

Architecture

The original transformer architecture presented in this paper consists of an encoder and a decoder, a design that matches the seq2seq nature of machine translation. There are also encoder-only (e.g. BERT) and decoder-only (e.g. GPT) transformer architectures, but those won't be covered here. One of the main features of transformers in general is parallelized sequence processing, which RNNs lack. The main ingredient is the attention mechanism, which creates modified word representations (attention representations) that take into account a word's meaning in relation to the other words in the sequence. For example, the word "bank" can mean a financial institution or the land along the edge of a river, as in "river bank", so we may choose to represent it differently depending on its context. This goes beyond the limits of traditional word embeddings, which assign each word a single fixed vector.
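
The attention mechanism can be made concrete with a short example. The paper defines scaled dot-product attention as softmax(QK^T / sqrt(d_k)) V. The block below is a minimal pure-Python sketch of that formula for a single head, without masking, dropout, or learned projections. It is written for readability and is not the code in this repository, and the function names are illustrative.

```python
import math
from typing import List

Matrix = List[List[float]]


def softmax(row: List[float]) -> List[float]:
    """Numerically stable softmax over one row of scores."""
    highest = max(row)
    exps = [math.exp(value - highest) for value in row]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_dot_product_attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """softmax(Q K^T / sqrt(d_k)) V for q: (n_q, d_k), k: (n_k, d_k), v: (n_k, d_v)."""
    d_k = len(k[0])
    d_v = len(v[0])
    output = []
    for q_row in q:
        # Similarity of this query to every key, scaled by sqrt(d_k).
        scores = [sum(a * b for a, b in zip(q_row, k_row)) / math.sqrt(d_k) for k_row in k]
        weights = softmax(scores)
        # Each output row is a weighted average of the value rows.
        output.append([sum(w * v_row[j] for w, v_row in zip(weights, v)) for j in range(d_v)])
    return output


keys = [[1.0, 0.0], [0.0, 1.0]]
values = [[1.0, 2.0], [3.0, 4.0]]
# A query that matches neither key more than the other averages the two values.
print(scaled_dot_product_attention([[0.0, 0.0]], keys, values))
```

A query that aligns strongly with one key puts almost all of its weight on that key's value, which is how a word's representation can shift toward the context words most relevant to it.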

For a detailed walkthrough of the architecture check the notebooks/tutorial.ipynb

Weights and Biases Logs

Weights and Biases is a very powerful tool for MLOps. I integrated it with this project to automatically provide logs and visualizations during training. You can see how training went for the pretrained models at this project link. All logs and visualizations are synced to the cloud in real time.

When you start training you will be asked:

wandb: (1) Create W&B account
wandb: (2) Use an existing W&B account
wandb: (3) Don't visualize my results
wandb: Enter your choice: 

To create the visualizations and sync them to the cloud, you will need a W&B account. Creating and using an account takes less than a minute, and it's free. If you don't want to visualize results, select option 3.

Citation

Please use this bibtex if you want to cite this repository:

@misc{Koch2021attentionisallyouneed,
  author = {Koch, Brando},
  title = {attention-is-all-you-need},
  year = {2021},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/bkoch4142/MISSING}},
}

License

This repository is under an MIT License

Owner
Brando Koch
Machine Learning Engineer with experience in ML, DL, NLP & CV specializing in ConversationalAI & NLP.