Deeply Supervised, Layer-wise Prediction-aware (DSLP) Transformer for Non-autoregressive Neural Machine Translation

Overview

Non-Autoregressive Translation with Layer-Wise Prediction and Deep Supervision

Training Efficiency

We show the training efficiency of our DSLP model, built on the vanilla NAT model. Specifically, we compare the BLEU scores of vanilla NAT and vanilla NAT with DSLP & Mixed Training at the same training time (in hours).

As observed, our DSLP model achieves much higher BLEU scores shortly after training starts (~3 hours). This shows that DSLP is much more efficient to train, as it reaches higher BLEU scores at the same training cost.

Efficiency

We run the experiments with 8 Tesla V100 GPUs. The batch size is 128K tokens, and each model is trained with 300K updates.
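As a rough sanity check on the batch size, the number of tokens per optimizer update in fairseq is bounded by the number of GPUs times `--max-tokens` times `--update-freq`. The sketch below is illustrative arithmetic only and is not part of this repository; `effective_batch_tokens` is a hypothetical helper name. With `--max-tokens 8192` (the value used in the training commands) on 8 GPUs, the bound is 64K tokens, so reaching 128K would presumably need gradient accumulation over 2 steps or a different per-GPU token budget. Which of these was used is not stated here.

```python
def effective_batch_tokens(num_gpus: int, max_tokens: int, update_freq: int = 1) -> int:
    """Upper bound on the tokens consumed per optimizer update.

    Assumption: every GPU fills its --max-tokens budget on every step,
    which real batches only approximate.
    """
    return num_gpus * max_tokens * update_freq


# 8 GPUs with --max-tokens 8192 and no gradient accumulation.
print(effective_batch_tokens(num_gpus=8, max_tokens=8192))                 # 65536
# The same setup with a hypothetical --update-freq 2.
print(effective_batch_tokens(num_gpus=8, max_tokens=8192, update_freq=2))  # 131072
```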

Replication

We provide the scripts for replicating the results on the WMT'14 EN-DE task.

Dataset

We download the distilled data from FairSeq.

Preprocess it with:

TEXT=wmt14_ende_distill
python3 fairseq_cli/preprocess.py --source-lang en --target-lang de \
   --trainpref $TEXT/train.en-de --validpref $TEXT/valid.en-de --testpref $TEXT/test.en-de \
   --destdir data-bin/wmt14.en-de_kd --workers 40 --joined-dictionary

Training:

GLAT with DSLP

python3 train.py data-bin/wmt14.en-de_kd --source-lang en --target-lang de  --save-dir checkpoints  --eval-tokenized-bleu \
   --keep-interval-updates 5 --save-interval-updates 500 --validate-interval-updates 500 --maximize-best-checkpoint-metric \
   --eval-bleu-remove-bpe --eval-bleu-print-samples --best-checkpoint-metric bleu --log-format simple --log-interval 100 \
   --eval-bleu --eval-bleu-detok space --keep-last-epochs 5 --keep-best-checkpoints 5  --fixed-validation-seed 7 --ddp-backend=no_c10d \
   --share-all-embeddings --decoder-learned-pos --encoder-learned-pos  --optimizer adam --adam-betas "(0.9,0.98)" --lr 0.0005 \
   --lr-scheduler inverse_sqrt --stop-min-lr 1e-09 --warmup-updates 10000 --warmup-init-lr 1e-07 --apply-bert-init --weight-decay 0.01 \
   --fp16 --clip-norm 2.0 --max-update 300000  --task translation_glat --criterion glat_loss --arch glat_sd --noise full_mask \
   --src-upsample-scale 2 --use-ctc-decoder --ctc-beam-size 1  --concat-yhat --concat-dropout 0.0  --label-smoothing 0.1 \
   --activation-fn gelu --dropout 0.1  --max-tokens 8192 --glat-mode glat 

CMLM with DSLP

python3 train.py data-bin/wmt14.en-de_kd --source-lang en --target-lang de  --save-dir checkpoints  --eval-tokenized-bleu \
   --keep-interval-updates 5 --save-interval-updates 500 --validate-interval-updates 500 --maximize-best-checkpoint-metric \
   --eval-bleu-remove-bpe --eval-bleu-print-samples --best-checkpoint-metric bleu --log-format simple --log-interval 100 \
   --eval-bleu --eval-bleu-detok space --keep-last-epochs 5 --keep-best-checkpoints 5  --fixed-validation-seed 7 --ddp-backend=no_c10d \
   --share-all-embeddings --decoder-learned-pos --encoder-learned-pos  --optimizer adam --adam-betas "(0.9,0.98)" --lr 0.0005 \
   --lr-scheduler inverse_sqrt --stop-min-lr 1e-09 --warmup-updates 10000 --warmup-init-lr 1e-07 --apply-bert-init --weight-decay 0.01 \
   --fp16 --clip-norm 2.0 --max-update 300000  --task translation_lev --criterion nat_loss --arch glat_sd --noise full_mask \
   --src-upsample-scale 2 --use-ctc-decoder --ctc-beam-size 1  --concat-yhat --concat-dropout 0.0  --label-smoothing 0.1 \
   --activation-fn gelu --dropout 0.1  --max-tokens 8192 

Vanilla NAT with DSLP

python3 train.py data-bin/wmt14.en-de_kd --source-lang en --target-lang de  --save-dir checkpoints  --eval-tokenized-bleu \
   --keep-interval-updates 5 --save-interval-updates 500 --validate-interval-updates 500 --maximize-best-checkpoint-metric \
   --eval-bleu-remove-bpe --eval-bleu-print-samples --best-checkpoint-metric bleu --log-format simple --log-interval 100 \
   --eval-bleu --eval-bleu-detok space --keep-last-epochs 5 --keep-best-checkpoints 5  --fixed-validation-seed 7 --ddp-backend=no_c10d \
   --share-all-embeddings --decoder-learned-pos --encoder-learned-pos  --optimizer adam --adam-betas "(0.9,0.98)" --lr 0.0005 \
   --lr-scheduler inverse_sqrt --stop-min-lr 1e-09 --warmup-updates 10000 --warmup-init-lr 1e-07 --apply-bert-init --weight-decay 0.01 \
   --fp16 --clip-norm 2.0 --max-update 300000  --task translation_lev --criterion nat_loss --arch nat_sd --noise full_mask \
   --src-upsample-scale 2 --use-ctc-decoder --ctc-beam-size 1  --concat-yhat --concat-dropout 0.0  --label-smoothing 0.1 \
   --activation-fn gelu --dropout 0.1  --max-tokens 8192 

Vanilla NAT with DSLP and Mixed Training:

python3 train.py data-bin/wmt14.en-de_kd --source-lang en --target-lang de  --save-dir checkpoints  --eval-tokenized-bleu \
   --keep-interval-updates 5 --save-interval-updates 500 --validate-interval-updates 500 --maximize-best-checkpoint-metric \
   --eval-bleu-remove-bpe --eval-bleu-print-samples --best-checkpoint-metric bleu --log-format simple --log-interval 100 \
   --eval-bleu --eval-bleu-detok space --keep-last-epochs 5 --keep-best-checkpoints 5  --fixed-validation-seed 7 --ddp-backend=no_c10d \
   --share-all-embeddings --decoder-learned-pos --encoder-learned-pos  --optimizer adam --adam-betas "(0.9,0.98)" --lr 0.0005 \
   --lr-scheduler inverse_sqrt --stop-min-lr 1e-09 --warmup-updates 10000 --warmup-init-lr 1e-07 --apply-bert-init --weight-decay 0.01 \
   --fp16 --clip-norm 2.0 --max-update 300000  --task translation_lev --criterion nat_loss --arch nat_sd --noise full_mask \
   --src-upsample-scale 2 --use-ctc-decoder --ctc-beam-size 1  --concat-yhat --concat-dropout 0.0  --label-smoothing 0.1 \
   --activation-fn gelu --dropout 0.1  --max-tokens 8192  --ss-ratio 0.3 --fixed-ss-ratio --masked-loss

CTC with DSLP:

python3 train.py data-bin/wmt14.en-de_kd --source-lang en --target-lang de  --save-dir checkpoints  --eval-tokenized-bleu \
   --keep-interval-updates 5 --save-interval-updates 500 --validate-interval-updates 500 --maximize-best-checkpoint-metric \
   --eval-bleu-remove-bpe --eval-bleu-print-samples --best-checkpoint-metric bleu --log-format simple --log-interval 100 \
   --eval-bleu --eval-bleu-detok space --keep-last-epochs 5 --keep-best-checkpoints 5  --fixed-validation-seed 7 --ddp-backend=no_c10d \
   --share-all-embeddings --decoder-learned-pos --encoder-learned-pos  --optimizer adam --adam-betas "(0.9,0.98)" --lr 0.0005 \
   --lr-scheduler inverse_sqrt --stop-min-lr 1e-09 --warmup-updates 10000 --warmup-init-lr 1e-07 --apply-bert-init --weight-decay 0.01 \
   --fp16 --clip-norm 2.0 --max-update 300000  --task translation_lev --criterion nat_loss --arch nat_ctc_sd --noise full_mask \
   --src-upsample-scale 2 --use-ctc-decoder --ctc-beam-size 1  --concat-yhat --concat-dropout 0.0  \
   --activation-fn gelu --dropout 0.1  --max-tokens 8192 

CTC with DSLP and Mixed Training:

python3 train.py data-bin/wmt14.en-de_kd --source-lang en --target-lang de  --save-dir checkpoints  --eval-tokenized-bleu \
   --keep-interval-updates 5 --save-interval-updates 500 --validate-interval-updates 500 --maximize-best-checkpoint-metric \
   --eval-bleu-remove-bpe --eval-bleu-print-samples --best-checkpoint-metric bleu --log-format simple --log-interval 100 \
   --eval-bleu --eval-bleu-detok space --keep-last-epochs 5 --keep-best-checkpoints 5  --fixed-validation-seed 7 --ddp-backend=no_c10d \
   --share-all-embeddings --decoder-learned-pos --encoder-learned-pos  --optimizer adam --adam-betas "(0.9,0.98)" --lr 0.0005 \
   --lr-scheduler inverse_sqrt --stop-min-lr 1e-09 --warmup-updates 10000 --warmup-init-lr 1e-07 --apply-bert-init --weight-decay 0.01 \
   --fp16 --clip-norm 2.0 --max-update 300000  --task translation_lev --criterion nat_loss --arch nat_ctc_sd_ss --noise full_mask \
   --src-upsample-scale 2 --use-ctc-decoder --ctc-beam-size 1  --concat-yhat --concat-dropout 0.0  \
   --activation-fn gelu --dropout 0.1  --max-tokens 8192 --ss-ratio 0.3 --fixed-ss-ratio
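The six training commands above share most of their flags and differ only in the task, criterion, architecture, label smoothing, and a few extra options. The following sketch collects those differences in one place; the values are transcribed from the commands above, while the dictionary keys and the `variant_args` helper are hypothetical names introduced here for illustration and do not exist in the repository.

```python
# Variant-specific flags, transcribed from the training commands.
# The keys ("glat_dslp", ...) are illustrative labels, not repository names.
VARIANTS = {
    "glat_dslp": ("translation_glat", "glat_loss", "glat_sd",
                  ["--label-smoothing", "0.1", "--glat-mode", "glat"]),
    "cmlm_dslp": ("translation_lev", "nat_loss", "glat_sd",
                  ["--label-smoothing", "0.1"]),
    "nat_dslp": ("translation_lev", "nat_loss", "nat_sd",
                 ["--label-smoothing", "0.1"]),
    "nat_dslp_mixed": ("translation_lev", "nat_loss", "nat_sd",
                       ["--label-smoothing", "0.1", "--ss-ratio", "0.3",
                        "--fixed-ss-ratio", "--masked-loss"]),
    "ctc_dslp": ("translation_lev", "nat_loss", "nat_ctc_sd", []),
    "ctc_dslp_mixed": ("translation_lev", "nat_loss", "nat_ctc_sd_ss",
                       ["--ss-ratio", "0.3", "--fixed-ss-ratio"]),
}


def variant_args(name):
    """Return the flags that distinguish one training variant from the others."""
    task, criterion, arch, extra = VARIANTS[name]
    return ["--task", task, "--criterion", criterion, "--arch", arch, *extra]


for name in VARIANTS:
    print(name, " ".join(variant_args(name)))
```

Keeping the shared flags in a single base command and appending `variant_args(...)` would be one way to avoid copy-paste drift between the variants.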

Evaluation

fairseq-generate data-bin/wmt14.en-de_kd  --path PATH_TO_A_CHECKPOINT \
    --gen-subset test --task translation_lev --iter-decode-max-iter 0 \
    --iter-decode-eos-penalty 0 --beam 1 --remove-bpe --print-step --batch-size 100

Note: 1) For CTC-based models, add --plain-ctc --model-overrides '{"ctc_beam_size": 1, "plain_ctc": True}'; 2) for GLAT-based models, change the task to translation_glat.
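The two cases in the note can be captured in a small command builder. This is a minimal sketch, assuming the evaluation command shown above; `build_generate_cmd` and its `family` argument are hypothetical names used only for illustration.

```python
import shlex


def build_generate_cmd(checkpoint, family="vanilla",
                       data_dir="data-bin/wmt14.en-de_kd"):
    """Assemble the fairseq-generate call for one model family.

    `family` is an illustrative label: "vanilla", "cmlm", "glat" or "ctc".
    """
    if family not in {"vanilla", "cmlm", "glat", "ctc"}:
        raise ValueError(f"unknown model family: {family}")
    # GLAT-based models use their own task; all others use translation_lev.
    task = "translation_glat" if family == "glat" else "translation_lev"
    cmd = [
        "fairseq-generate", data_dir, "--path", checkpoint,
        "--gen-subset", "test", "--task", task,
        "--iter-decode-max-iter", "0", "--iter-decode-eos-penalty", "0",
        "--beam", "1", "--remove-bpe", "--print-step", "--batch-size", "100",
    ]
    if family == "ctc":
        # Extra flags for CTC-based models, copied from the note.
        cmd += ["--plain-ctc", "--model-overrides",
                '{"ctc_beam_size": 1, "plain_ctc": True}']
    return cmd


# Print a shell-quoted command line for a CTC-based checkpoint.
print(shlex.join(build_generate_cmd("PATH_TO_A_CHECKPOINT", family="ctc")))
```

Returning an argument list rather than a single string leaves the quoting of the `--model-overrides` value to `shlex.join` (or to `subprocess.run`, if the command is executed directly).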

Output

We also provide the outputs of CTC w/ DSLP, CTC w/ DSLP & Mixed Training, Vanilla NAT w/ DSLP, Vanilla NAT w/ DSLP & Mixed Training, GLAT w/ DSLP, and CMLM w/ DSLP for review purposes.

| Model | Reference | Hypothesis |
| --- | --- | --- |
| CTC w/ DSLP | ref | hyp |
| CTC w/ DSLP & Mixed Training | ref | hyp |
| Vanilla NAT w/ DSLP | ref | hyp |
| Vanilla NAT w/ DSLP & Mixed Training | ref | hyp |
| GLAT w/ DSLP | ref | hyp |
| CMLM w/ DSLP | ref | hyp |

Note: The output is on WMT'14 EN-DE. The references are paired with hypotheses for each model.
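To pair references with hypotheses from a fresh evaluation run, the log printed by fairseq-generate can be parsed directly. The sketch below assumes the standard fairseq-generate log layout, in which reference lines look like `T-<id><TAB><text>` and hypothesis lines look like `H-<id><TAB><score><TAB><text>`; `pair_refs_and_hyps` is a hypothetical helper, the sample lines are made-up placeholders, and whether the provided output files were produced this way is not stated here.

```python
def pair_refs_and_hyps(lines):
    """Pair reference (T-) and hypothesis (H-) lines by sentence id."""
    refs, hyps = {}, {}
    for line in lines:
        parts = line.rstrip("\n").split("\t")
        tag = parts[0]
        if tag.startswith("T-") and len(parts) >= 2:
            refs[int(tag[2:])] = parts[1]
        elif tag.startswith("H-") and len(parts) >= 3:
            # parts[1] is the model score; the text follows it.
            hyps[int(tag[2:])] = parts[2]
    # Keep only ids that have both a reference and a hypothesis.
    return [(refs[i], hyps[i]) for i in sorted(refs) if i in hyps]


# Placeholder log lines for illustration only (not real model output).
sample = [
    "S-1\tsource one",
    "T-1\tref one",
    "H-1\t-0.5\thyp one",
    "I-1\t0",
    "T-0\tref zero",
    "H-0\t-0.2\thyp zero",
]
for ref, hyp in pair_refs_and_hyps(sample):
    print(ref, "|||", hyp)
```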

Owner
Chenyang Huang