An implementation of the Pay Attention when Required transformer


Pay Attention when Required (PAR) Transformer-XL

An implementation of the Pay Attention when Required transformer from the paper: https://arxiv.org/pdf/2009.04534.pdf

[Figure: main GIF of the learned layer-type probabilities per stochastic block. Source: Jonathan Kernes]

Quick overview

The Pay Attention when Required Transformer (Mandava et al., 2020) is just a regular Transformer-XL (Dai et al., 2019, https://arxiv.org/pdf/1901.02860.pdf), but with the ratio of attention to dense layers optimized. This optimization is performed by allowing the network to choose which type of layer it prefers in each block of the network. The present implementation is not an exact replica of the authors' work. Instead, we optimize the model architecture and the model parameters simultaneously. The search is performed using a SuperNet, which is a sequential neural network composed of stochastic blocks, as shown in the figure below (taken from the paper; please don't sue me!).

[Figure: SuperNet composed of stochastic blocks. Source: Mandava et al., 2020]
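The SuperNet idea can be sketched without any framework. A stochastic block runs every candidate layer and mixes their outputs with a weight vector over layer types; during the search, that vector is a relaxed categorical sample. This is a hypothetical illustration, not the repository's code. Plain Python lists stand in for tensors, and `attention`, `dense` and `identity` are toy placeholders rather than real layers.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def attention(x: Vector) -> Vector:
    # Toy "mixing" op (placeholder): every position receives the mean of the input.
    mean = sum(x) / len(x)
    return [mean for _ in x]


def dense(x: Vector) -> Vector:
    # Toy position-wise op (placeholder): a fixed affine map applied elementwise.
    return [2.0 * v + 1.0 for v in x]


def identity(x: Vector) -> Vector:
    return list(x)


CANDIDATES: Sequence[Callable[[Vector], Vector]] = (attention, dense, identity)


def stochastic_block(x: Vector, weights: Sequence[float]) -> Vector:
    """Return sum_k weights[k] * candidate_k(x); weights should sum to 1."""
    if len(weights) != len(CANDIDATES):
        raise ValueError("need exactly one weight per candidate layer")
    outputs = [layer(x) for layer in CANDIDATES]
    return [sum(w * out[i] for w, out in zip(weights, outputs)) for i in range(len(x))]


if __name__ == "__main__":
    print(stochastic_block([1.0, 3.0], [0.0, 1.0, 0.0]))  # one-hot: pure dense layer
    print(stochastic_block([1.0, 3.0], [0.5, 0.0, 0.5]))  # soft mix of attention and identity
```

With a one-hot weight vector, the block reduces to a single chosen layer. This is the limit that the search drives each block toward.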

The key component is a Gumbel-Softmax layer (Jang et al., 2016, https://arxiv.org/pdf/1611.01144.pdf; Maddison et al., 2016). This layer is a continuous relaxation of a discrete sample from a categorical distribution, which lets us use gradients to learn the parameters of a discrete distribution. (Recall that a categorical distribution is a distribution over K states, with the kth state having probability \pi_k, subject to the normalization condition \sum_{k=1}^{K} \pi_k = 1.)
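Here is a minimal framework-free sketch of a Gumbel-Softmax sample, assuming the logits are log-probabilities \log \pi_k. The function names are hypothetical and not the repository's API. The sketch adds Gumbel(0, 1) noise to each logit and applies a temperature-scaled softmax.

```python
import math
import random
from typing import List, Optional, Sequence


def softmax(logits: Sequence[float], tau: float = 1.0) -> List[float]:
    """Temperature-scaled softmax, shifted by the max for numerical stability."""
    scaled = [x / tau for x in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def gumbel_softmax_sample(log_pi: Sequence[float], tau: float,
                          rng: Optional[random.Random] = None) -> List[float]:
    """Relaxed one-hot sample: softmax((log pi_k + g_k) / tau) with g_k ~ Gumbel(0, 1)."""
    rng = rng or random.Random()
    # Inverse-CDF sampling of Gumbel(0, 1): g = -log(-log(u)), u ~ Uniform(0, 1).
    # max(..., 1e-12) avoids log(0) if the generator ever returns exactly 0.0.
    gumbels = [-math.log(-math.log(max(rng.random(), 1e-12))) for _ in log_pi]
    return softmax([lp + g for lp, g in zip(log_pi, gumbels)], tau)


if __name__ == "__main__":
    log_pi = [math.log(p) for p in (0.2, 0.5, 0.3)]  # pi over (attention, dense, identity)
    print(gumbel_softmax_sample(log_pi, tau=1.0, rng=random.Random(0)))
    print(gumbel_softmax_sample(log_pi, tau=0.1, rng=random.Random(0)))
```

Taking the argmax of a sample gives an exact draw from the categorical distribution (the Gumbel-max trick). The softmax replaces that argmax with a differentiable approximation whose sharpness is set by tau.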

As the model learns, it is free to adjust both the usual model parameters and its architecture search parameters \pi, which give the probability of choosing one of

  1. Attention

  2. Dense

  3. Identity

for any given stochastic block. We perform simulated annealing. Because the categorical distribution is approximated by a continuous representation, a sample yields soft scores such as (0.01, 0.98, 0.01) rather than an exact one-hot vector; here, state 2 is almost certainly picked. The sharpness of these scores is set by a temperature parameter \tau, and the samples become exactly categorical (one-hot) in the limit \tau \to 0. Simulated annealing means we begin with \tau = 1 to let the model figure out what it wants, then slowly decrease \tau so the distribution approaches a categorical.
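The sketch below makes the annealing concrete. It pairs a temperature-scaled softmax with a hypothetical exponential schedule. The schedule shape and its constants (`tau0`, `rate`, `tau_min`) are assumptions for illustration, and the repository's actual schedule may differ. With the logits held fixed, the output approaches a one-hot vector as \tau shrinks.

```python
import math
from typing import List, Sequence


def anneal_tau(step: int, tau0: float = 1.0, rate: float = 1e-4,
               tau_min: float = 0.05) -> float:
    """Hypothetical schedule: exponential decay from tau0, floored at tau_min."""
    return max(tau_min, tau0 * math.exp(-rate * step))


def softmax(logits: Sequence[float], tau: float) -> List[float]:
    """Temperature-scaled softmax, shifted by the max for numerical stability."""
    scaled = [x / tau for x in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    logits = [0.1, 1.0, 0.2]  # fixed architecture logits for one stochastic block
    for step in (0, 10_000, 30_000):
        tau = anneal_tau(step)
        print(step, round(tau, 3), [round(p, 3) for p in softmax(logits, tau)])
```

Flooring tau at a small positive value keeps the division well defined. An exact \tau = 0 would correspond to a hard argmax.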

All experiments use the freely available wiki-text2 dataset.

Explanation of the main GIF: the GIF summarizes our experiments. It shows the \pi distribution for each stochastic block of a 6-block SuperNet as a function of training iterations. The number indicates the probability of the most likely layer type (darker means more probable). As you can see, the model learns to put attention layers at the beginning and dense layers at the end.

Requirements

Usual ML stuff: if you have a conda environment with Python 3+ and TensorFlow 2+, you should be OK. You will also need TensorFlow Text to handle the SentencePiece tokenization.

If you choose to run your own tokenizer (a flag option in data_utils for handling new text data), you will also need to download the SentencePiece package: https://github.com/google/sentencepiece

Data

The dataset used is Wiki-text2. We have provided a copy of this in the data folder, along with some preprocessed data for training. In order to reproduce this from scratch, run the shell script

./create_tfrecords.sh

This downloads the wiki-text2 dataset from its source, then cleans, batches, and writes the data to a tfrecords file. The shell script calls build_data.py, which offers more control over what type of data to generate. The main parameters you will want to tune are:

  * batch_size

  * seq_len

You can also supply your own dataset instead of the one provided. The underlying tokenizer is SentencePiece (Kudo, https://github.com/google/sentencepiece), which operates directly on raw text and can handle any kind of input. Simply point the --input_text flag at your file and set the desired --vocab_size.

Why do we need to specify the batch size? Transformer-XL uses memory states to form a recurrent, long-range network. After processing a segment, say [A,B] of the sequence [A,B,C,D], the cached hidden states from [A,B] are fed into the [C,D] computation with a stop gradient. Therefore, each datapoint must follow chronologically from the previous one in the same batch position.

This is achieved by context batching (see the corresponding function in data_utils.py). We break the entire dataset into batch_size contiguous segments, then pull one sequence from each segment at a time, in order, to form each batch. Because of this, splitting the data into more shards can discard a large amount of data (on the order of batch_size*seq_len*shards tokens). Each shard drops its trailing partial batch (up to batch_size*seq_len tokens) to keep the tensor shapes fixed.
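A minimal sketch of context batching on a plain list of token ids follows. `context_batches` is a hypothetical helper, not the function in data_utils.py. A real language-model pipeline would also produce targets shifted by one token, which is omitted here.

```python
from typing import List, Sequence


def context_batches(tokens: Sequence[int], batch_size: int,
                    seq_len: int) -> List[List[List[int]]]:
    """Hypothetical context batching: split one token stream into batch_size
    contiguous streams, then read them in lockstep, seq_len tokens at a time.

    Row b of batch t+1 continues exactly where row b of batch t stopped, which is
    what Transformer-XL's memory needs. Leftover tokens that do not fill a whole
    stream or a whole segment are dropped.
    """
    stream_len = len(tokens) // batch_size      # tokens per stream
    n_steps = stream_len // seq_len             # full segments per stream
    streams = [list(tokens[b * stream_len:(b + 1) * stream_len])
               for b in range(batch_size)]
    return [[s[t * seq_len:(t + 1) * seq_len] for s in streams]
            for t in range(n_steps)]


if __name__ == "__main__":
    for step, batch in enumerate(context_batches(list(range(20)), batch_size=2, seq_len=3)):
        print(step, batch)
```

With 20 tokens, batch_size=2 and seq_len=3, each stream holds 10 tokens, so three batches are produced and tokens 9 and 19 are dropped. This is a small-scale version of the remainder loss described above.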

Additional technical details

Following the original Transformer-XL, we also implement an adaptive softmax layer (Grave et al., 2017, https://arxiv.org/abs/1609.04309) to deal with a potentially large number of outputs in the final dense layer. This implementation is inspired by the TF 1.0 example at https://github.com/yangsaiyong/tf-adaptive-softmax-lstm-lm. To use the adaptive softmax, set the --cutoffs= flag in train.py. The cutoffs are the max values of each bin and should NOT include the vocab size (i.e. the max cutoff of the final bin). If no cutoffs are specified, the model defaults to a normal softmax.
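The sketch below shows how cutoffs partition the vocabulary. It assumes token ids are sorted by decreasing frequency and that each cutoff is an exclusive upper bound of its bin. The helpers and the example values (vocab size 10000, cutoffs 2000 and 6000) are hypothetical. As in Grave et al., the head holds the most frequent tokens plus one logit per tail cluster.

```python
from typing import List, Sequence, Tuple


def cluster_bounds(cutoffs: Sequence[int], vocab_size: int) -> List[Tuple[int, int]]:
    """Half-open [start, end) token-id ranges: the head first, then each tail cluster."""
    edges = [0, *cutoffs, vocab_size]
    if any(a >= b for a, b in zip(edges, edges[1:])):
        raise ValueError("cutoffs must be strictly increasing, positive and < vocab_size")
    return list(zip(edges[:-1], edges[1:]))


def locate(token_id: int, cutoffs: Sequence[int], vocab_size: int) -> Tuple[int, int]:
    """Return (cluster index, position inside that cluster); cluster 0 is the head."""
    for cluster, (start, end) in enumerate(cluster_bounds(cutoffs, vocab_size)):
        if start <= token_id < end:
            return cluster, token_id - start
    raise ValueError(f"token id {token_id} outside [0, {vocab_size})")


def head_size(cutoffs: Sequence[int], vocab_size: int) -> int:
    """Head logits: the cutoffs[0] most frequent tokens plus one logit per tail cluster."""
    return cutoffs[0] + len(cutoffs) if cutoffs else vocab_size


if __name__ == "__main__":
    print(cluster_bounds([2000, 6000], 10000))
    print(locate(2500, [2000, 6000], 10000), head_size([2000, 6000], 10000))
```

With no cutoffs, the head covers the whole vocabulary, which matches the fallback to a normal softmax.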

For completeness, we also provide a script, optimal_cuts.py, that determines the optimal cutoffs given a whitespace-separated file of unigram probabilities. It is based on the assumptions of Grave et al. regarding GPU computational complexity; see the paper for details. The algorithm uses dynamic programming, but is quite slow at O(KN^2) for K cutoffs and N vocab words. In principle, determining the cutoffs is a one-time cost, but we are impatient and recommend simply experimenting with the cutoffs instead. See the script for flag details.
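The sketch below illustrates the O(KN^2) dynamic program. It uses a deliberately simplified cost model that is an assumption here, not the GPU cost model used by optimal_cuts.py. The head always costs its size (c_1 + K). Each tail cluster costs its probability mass times its size. `optimal_cutoffs` is a hypothetical name.

```python
from typing import List, Sequence, Tuple


def optimal_cutoffs(probs: Sequence[float], k: int) -> Tuple[List[int], float]:
    """Choose k cutoffs minimising a simplified expected cost (an assumption).

    probs: unigram probabilities sorted in decreasing order (token id == rank).
    cost(c_1..c_k) = (c_1 + k) + sum_i P[c_i, c_{i+1}) * (c_{i+1} - c_i), c_{k+1} = N.
    """
    n = len(probs)
    if not 1 <= k < n:
        raise ValueError("need 1 <= k < len(probs)")
    prefix = [0.0]
    for p in probs:
        prefix.append(prefix[-1] + p)

    def tail_cost(i: int, j: int) -> float:
        # Expected cost of a tail cluster holding ranks [i, j).
        return (prefix[j] - prefix[i]) * (j - i)

    inf = float("inf")
    # best[m][i]: minimal cost of covering ranks [i, n) with exactly m tail clusters.
    best = [[inf] * (n + 1) for _ in range(k + 1)]
    nxt = [[-1] * (n + 1) for _ in range(k + 1)]
    best[0][n] = 0.0
    for m in range(1, k + 1):
        for i in range(n - 1, -1, -1):
            for j in range(i + 1, n + 1):
                cand = tail_cost(i, j) + best[m - 1][j]
                if cand < best[m][i]:
                    best[m][i], nxt[m][i] = cand, j
    # The head covers ranks [0, c_1) plus one logit per tail cluster.
    c1 = min(range(1, n), key=lambda i: i + k + best[k][i])
    total = c1 + k + best[k][c1]
    cuts, i = [c1], c1
    for m in range(k, 1, -1):  # follow the stored split points
        i = nxt[m][i]
        cuts.append(i)
    return cuts, total


if __name__ == "__main__":
    print(optimal_cutoffs([0.5, 0.2, 0.1, 0.1, 0.05, 0.05], k=2))
```

The three nested loops give the O(KN^2) running time quoted above. This is why the real script is slow for large vocabularies.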

Training and Benchmarks

The default model has memory length 16, feed-forward dimension 1024, attention dimension 128, and 6 stochastic blocks, with an adaptive softmax layer and 2 clusters. We trained on a Colab GPU for 20 epochs, taking a total of 37 minutes. We use an Adam optimizer with cosine learning-rate decay: an initial warmup of 4000 steps up to a maximum learning rate of 1e-4, decaying to zero at the end of training. Our training benchmarks are:

Iteration   Train perplexity   Validation perplexity   Elapsed time
2.7k        163.9              114.4                   1m 58s
8.5k        78.56              62.33                   5m 37s
14.1k       65.71              51.88                   9m 28s
28.3k       48.52              42.61                   18m 40s
48.1k       41.85              39.57                   31m 51s
56.5k       42.12              39.41                   37m 14s
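The learning-rate schedule can be sketched as a plain function of the step. The linear warmup shape is an assumption; the README gives only the warmup length (4000 steps), the peak rate (1e-4), and decay to zero. The function name and `total_steps` value are illustrative.

```python
import math


def learning_rate(step: int, total_steps: int, warmup_steps: int = 4000,
                  max_lr: float = 1e-4) -> float:
    """Linear warmup to max_lr, then cosine decay to zero at total_steps.

    The linear warmup shape is an assumption; the README only specifies the
    warmup length, the peak rate and that the rate decays to zero.
    """
    if step < warmup_steps:
        return max_lr * step / warmup_steps
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    return 0.5 * max_lr * (1.0 + math.cos(math.pi * progress))


if __name__ == "__main__":
    total = 56_500  # roughly the final iteration count in the table above
    for s in (0, 2_000, 4_000, 30_000, total):
        print(s, learning_rate(s, total))
```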

To train, simply run the shell script

./base_model.sh

adjusting the parameters as you see fit. The above model is the default configuration. To train in Colab, simply open the notebook "colab.ipynb" and follow the instructions. This is most easily done by going to [colab.research.google.com] and searching for this repository on GitHub. The benefit of Colab is that it's easier to play around with the model after training.

We provide two ways to monitor training:

  1. A TensorBoard log. The Colab notebook takes care of running this for you. In the terminal, first create a 'logs' directory, then run the command tensorboard --logdir logs in a separate tab. This starts a local server where you can view live plots of the learning rate, tau annealing, train/valid loss, and perplexity.

  2. An output log saved to training_log.log. This records the model summary, parameters, etc., and prints loss updates every 100 steps, saving them to the log file.
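Perplexity is conventionally the exponential of the mean per-token cross-entropy, in nats. The minimal sketch below assumes per-token natural-log probabilities are available. `perplexity` is a hypothetical helper, not the repository's logging code.

```python
import math
from typing import Sequence


def perplexity(token_log_probs: Sequence[float]) -> float:
    """exp of the mean negative log-likelihood (natural log) per token."""
    if not token_log_probs:
        raise ValueError("need at least one token")
    mean_nll = -sum(token_log_probs) / len(token_log_probs)
    return math.exp(mean_nll)


if __name__ == "__main__":
    # A model that is uniform over 100 tokens has perplexity 100.
    print(perplexity([math.log(1 / 100)] * 5))
```

Because perplexity is the exponential of the loss, a small drop in loss late in training still shows up as a visible drop in perplexity.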

Thanks for reading this far!

Enjoy! And thank you to the wonderful researchers that inspired this project.

If you would like to contribute, or have any comments, questions, or concerns, please open a pull request or email me directly.
