Deeply Supervised, Layer-wise Prediction-aware (DSLP) Transformer for Non-autoregressive Neural Machine Translation

Overview

Non-Autoregressive Translation with Layer-Wise Prediction and Deep Supervision

Training Efficiency

We show the training efficiency of our DSLP model based on vanilla NAT model. Specifically, we compared the BLUE socres of vanilla NAT and vanilla NAT with DSLP & Mixed Training on the same traning time (in hours).

As we observed, our DSLP model achieves much higher BLUE scores shortly after the training started (~3 hours). It shows that our DSLP is much more efficient in training, as our model ahieves higher BLUE scores with the same amount of training cost.

Efficiency

We run the experiments with 8 Tesla V100 GPUs. The batch size is 128K tokens, and each model is trained with 300K updates.

Replication

We provide the scripts of replicating the results on WMT'14 EN-DE task.

Dataset

We download the distilled data from FairSeq

Preprocessed by

TEXT=wmt14_ende_distill
python3 fairseq_cli/preprocess.py --source-lang en --target-lang de \
   --trainpref $TEXT/train.en-de --validpref $TEXT/valid.en-de --testpref $TEXT/test.en-de \
   --destdir data-bin/wmt14.en-de_kd --workers 40 --joined-dictionary

Training:

GLAT with DSLP

python3 train.py data-bin/wmt14.en-de_kd --source-lang en --target-lang de  --save-dir checkpoints  --eval-tokenized-bleu \
   --keep-interval-updates 5 --save-interval-updates 500 --validate-interval-updates 500 --maximize-best-checkpoint-metric \
   --eval-bleu-remove-bpe --eval-bleu-print-samples --best-checkpoint-metric bleu --log-format simple --log-interval 100 \
   --eval-bleu --eval-bleu-detok space --keep-last-epochs 5 --keep-best-checkpoints 5  --fixed-validation-seed 7 --ddp-backend=no_c10d \
   --share-all-embeddings --decoder-learned-pos --encoder-learned-pos  --optimizer adam --adam-betas "(0.9,0.98)" --lr 0.0005 \ 
   --lr-scheduler inverse_sqrt --stop-min-lr 1e-09 --warmup-updates 10000 --warmup-init-lr 1e-07 --apply-bert-init --weight-decay 0.01 \
   --fp16 --clip-norm 2.0 --max-update 300000  --task translation_glat --criterion glat_loss --arch glat_sd --noise full_mask \ 
   --src-upsample-scale 2 --use-ctc-decoder --ctc-beam-size 1  --concat-yhat --concat-dropout 0.0  --label-smoothing 0.1 \ 
   --activation-fn gelu --dropout 0.1  --max-tokens 8192 --glat-mode glat 

CMLM with DSLP

python3 train.py data-bin/wmt14.en-de_kd --source-lang en --target-lang de  --save-dir checkpoints  --eval-tokenized-bleu \
   --keep-interval-updates 5 --save-interval-updates 500 --validate-interval-updates 500 --maximize-best-checkpoint-metric \
   --eval-bleu-remove-bpe --eval-bleu-print-samples --best-checkpoint-metric bleu --log-format simple --log-interval 100 \
   --eval-bleu --eval-bleu-detok space --keep-last-epochs 5 --keep-best-checkpoints 5  --fixed-validation-seed 7 --ddp-backend=no_c10d \
   --share-all-embeddings --decoder-learned-pos --encoder-learned-pos  --optimizer adam --adam-betas "(0.9,0.98)" --lr 0.0005 \ 
   --lr-scheduler inverse_sqrt --stop-min-lr 1e-09 --warmup-updates 10000 --warmup-init-lr 1e-07 --apply-bert-init --weight-decay 0.01 \
   --fp16 --clip-norm 2.0 --max-update 300000  --task translation_lev --criterion nat_loss --arch glat_sd --noise full_mask \ 
   --src-upsample-scale 2 --use-ctc-decoder --ctc-beam-size 1  --concat-yhat --concat-dropout 0.0  --label-smoothing 0.1 \ 
   --activation-fn gelu --dropout 0.1  --max-tokens 8192 

Vanilla NAT with DSLP

python3 train.py data-bin/wmt14.en-de_kd --source-lang en --target-lang de  --save-dir checkpoints  --eval-tokenized-bleu \
   --keep-interval-updates 5 --save-interval-updates 500 --validate-interval-updates 500 --maximize-best-checkpoint-metric \
   --eval-bleu-remove-bpe --eval-bleu-print-samples --best-checkpoint-metric bleu --log-format simple --log-interval 100 \
   --eval-bleu --eval-bleu-detok space --keep-last-epochs 5 --keep-best-checkpoints 5  --fixed-validation-seed 7 --ddp-backend=no_c10d \
   --share-all-embeddings --decoder-learned-pos --encoder-learned-pos  --optimizer adam --adam-betas "(0.9,0.98)" --lr 0.0005 \ 
   --lr-scheduler inverse_sqrt --stop-min-lr 1e-09 --warmup-updates 10000 --warmup-init-lr 1e-07 --apply-bert-init --weight-decay 0.01 \
   --fp16 --clip-norm 2.0 --max-update 300000  --task translation_lev --criterion nat_loss --arch nat_sd --noise full_mask \ 
   --src-upsample-scale 2 --use-ctc-decoder --ctc-beam-size 1  --concat-yhat --concat-dropout 0.0  --label-smoothing 0.1 \ 
   --activation-fn gelu --dropout 0.1  --max-tokens 8192 

Vanilla NAT with DSLP and Mixed Training:

python3 train.py data-bin/wmt14.en-de_kd --source-lang en --target-lang de  --save-dir checkpoints  --eval-tokenized-bleu \
   --keep-interval-updates 5 --save-interval-updates 500 --validate-interval-updates 500 --maximize-best-checkpoint-metric \
   --eval-bleu-remove-bpe --eval-bleu-print-samples --best-checkpoint-metric bleu --log-format simple --log-interval 100 \
   --eval-bleu --eval-bleu-detok space --keep-last-epochs 5 --keep-best-checkpoints 5  --fixed-validation-seed 7 --ddp-backend=no_c10d \
   --share-all-embeddings --decoder-learned-pos --encoder-learned-pos  --optimizer adam --adam-betas "(0.9,0.98)" --lr 0.0005 \ 
   --lr-scheduler inverse_sqrt --stop-min-lr 1e-09 --warmup-updates 10000 --warmup-init-lr 1e-07 --apply-bert-init --weight-decay 0.01 \
   --fp16 --clip-norm 2.0 --max-update 300000  --task translation_lev --criterion nat_loss --arch nat_sd --noise full_mask \ 
   --src-upsample-scale 2 --use-ctc-decoder --ctc-beam-size 1  --concat-yhat --concat-dropout 0.0  --label-smoothing 0.1 \ 
   --activation-fn gelu --dropout 0.1  --max-tokens 8192  --ss-ratio 0.3 --fixed-ss-ratio --masked-loss

CTC with DSLP:

python3 train.py data-bin/wmt14.en-de_kd --source-lang en --target-lang de  --save-dir checkpoints  --eval-tokenized-bleu \
   --keep-interval-updates 5 --save-interval-updates 500 --validate-interval-updates 500 --maximize-best-checkpoint-metric \
   --eval-bleu-remove-bpe --eval-bleu-print-samples --best-checkpoint-metric bleu --log-format simple --log-interval 100 \
   --eval-bleu --eval-bleu-detok space --keep-last-epochs 5 --keep-best-checkpoints 5  --fixed-validation-seed 7 --ddp-backend=no_c10d \
   --share-all-embeddings --decoder-learned-pos --encoder-learned-pos  --optimizer adam --adam-betas "(0.9,0.98)" --lr 0.0005 \ 
   --lr-scheduler inverse_sqrt --stop-min-lr 1e-09 --warmup-updates 10000 --warmup-init-lr 1e-07 --apply-bert-init --weight-decay 0.01 \
   --fp16 --clip-norm 2.0 --max-update 300000  --task translation_lev --criterion nat_loss --arch nat_ctc_sd --noise full_mask \ 
   --src-upsample-scale 2 --use-ctc-decoder --ctc-beam-size 1  --concat-yhat --concat-dropout 0.0  \ 
   --activation-fn gelu --dropout 0.1  --max-tokens 8192 

CTC with DSLP and Mixed Training:

python3 train.py data-bin/wmt14.en-de_kd --source-lang en --target-lang de  --save-dir checkpoints  --eval-tokenized-bleu \
   --keep-interval-updates 5 --save-interval-updates 500 --validate-interval-updates 500 --maximize-best-checkpoint-metric \
   --eval-bleu-remove-bpe --eval-bleu-print-samples --best-checkpoint-metric bleu --log-format simple --log-interval 100 \
   --eval-bleu --eval-bleu-detok space --keep-last-epochs 5 --keep-best-checkpoints 5  --fixed-validation-seed 7 --ddp-backend=no_c10d \
   --share-all-embeddings --decoder-learned-pos --encoder-learned-pos  --optimizer adam --adam-betas "(0.9,0.98)" --lr 0.0005 \ 
   --lr-scheduler inverse_sqrt --stop-min-lr 1e-09 --warmup-updates 10000 --warmup-init-lr 1e-07 --apply-bert-init --weight-decay 0.01 \
   --fp16 --clip-norm 2.0 --max-update 300000  --task translation_lev --criterion nat_loss --arch nat_ctc_sd_ss --noise full_mask \ 
   --src-upsample-scale 2 --use-ctc-decoder --ctc-beam-size 1  --concat-yhat --concat-dropout 0.0  \ 
   --activation-fn gelu --dropout 0.1  --max-tokens 8192 --ss-ratio 0.3 --fixed-ss-ratio

Evaluation

fairseq-generate data-bin/wmt14.en-de_kd  --path PATH_TO_A_CHECKPOINT \
    --gen-subset test --task translation_lev --iter-decode-max-iter 0 \
    --iter-decode-eos-penalty 0 --beam 1 --remove-bpe --print-step --batch-size 100

Note: 1) Add --plain-ctc --model-overrides '{"ctc_beam_size": 1, "plain_ctc": True}' if it is CTC based; 2) Change the task to translation_glat if it is GLAT based.

Output

We in addition provide the output of CTC w/ DSLP, CTC w/ DSLP & Mixed Training, Vanilla NAT w/ DSLP, Vanilla NAT w/ DSLP with Mixed Training, GLAT w/ DSLP, and CMLM w/ DSLP for review purpose.

Model Reference Hypothesis
CTC w/ DSLP ref hyp
CTC w/ DSLP & Mixed Training ref hyp
Vanilla NAT w/ DSLP ref hyp
Vanilla NAT w/ DSLP & Mixed Training ref hyp
GLAT w/ DSLP ref hyp
CMLM w/ DSLP ref hyp

Note: The output is on WMT'14 EN-DE. The references are paired with hypotheses for each model.

Owner
Chenyang Huang
Stay hungry, stay foolish
Chenyang Huang
People Interaction Graph

Gihan Jayatilaka*, Jameel Hassan*, Suren Sritharan*, Janith Senananayaka, Harshana Weligampola, et. al., 2021. Holistic Interpretation of Public Scenes Using Computer Vision and Temporal Graphs to Id

University of Peradeniya : COVID Research Group 1 Aug 24, 2022
Pop-Out Motion: 3D-Aware Image Deformation via Learning the Shape Laplacian (CVPR 2022)

Pop-Out Motion Pop-Out Motion: 3D-Aware Image Deformation via Learning the Shape Laplacian (CVPR 2022) Jihyun Lee*, Minhyuk Sung*, Hyunjin Kim, Tae-Ky

Jihyun Lee 88 Nov 22, 2022
Self-supervised learning algorithms provide a way to train Deep Neural Networks in an unsupervised way using contrastive losses

Self-supervised learning Self-supervised learning algorithms provide a way to train Deep Neural Networks in an unsupervised way using contrastive loss

Arijit Das 2 Mar 26, 2022
Writeups for the challenges from DownUnderCTF 2021

cloud Challenge Author Difficulty Release Round Bad Bucket Blue Alder easy round 1 Not as Bad Bucket Blue Alder easy round 1 Lost n Found Blue Alder m

DownUnderCTF 161 Dec 31, 2022
Geometry-Free View Synthesis: Transformers and no 3D Priors

Geometry-Free View Synthesis: Transformers and no 3D Priors Geometry-Free View Synthesis: Transformers and no 3D Priors Robin Rombach*, Patrick Esser*

CompVis Heidelberg 293 Dec 22, 2022
DeepMReye: magnetic resonance-based eye tracking using deep neural networks

DeepMReye: magnetic resonance-based eye tracking using deep neural networks

73 Dec 21, 2022
Official pytorch implementation of Active Learning for deep object detection via probabilistic modeling (ICCV 2021)

Active Learning for Deep Object Detection via Probabilistic Modeling This repository is the official PyTorch implementation of Active Learning for Dee

NVIDIA Research Projects 130 Jan 06, 2023
Mosaic of Object-centric Images as Scene-centric Images (MosaicOS) for long-tailed object detection and instance segmentation.

MosaicOS Mosaic of Object-centric Images as Scene-centric Images (MosaicOS) for long-tailed object detection and instance segmentation. Introduction M

Cheng Zhang 27 Oct 12, 2022
The source code for the Cutoff data augmentation approach proposed in this paper: "A Simple but Tough-to-Beat Data Augmentation Approach for Natural Language Understanding and Generation".

Cutoff: A Simple Data Augmentation Approach for Natural Language This repository contains source code necessary to reproduce the results presented in

Dinghan Shen 49 Dec 22, 2022
Compositional and Parameter-Efficient Representations for Large Knowledge Graphs

NodePiece - Compositional and Parameter-Efficient Representations for Large Knowledge Graphs NodePiece is a "tokenizer" for reducing entity vocabulary

Michael Galkin 107 Jan 04, 2023
A keras-based real-time model for medical image segmentation (CFPNet-M)

CFPNet-M: A Light-Weight Encoder-Decoder Based Network for Multimodal Biomedical Image Real-Time Segmentation This repository contains the implementat

268 Nov 27, 2022
[CVPR 2022 Oral] Versatile Multi-Modal Pre-Training for Human-Centric Perception

Versatile Multi-Modal Pre-Training for Human-Centric Perception Fangzhou Hong1  Liang Pan1  Zhongang Cai1,2,3  Ziwei Liu1* 1S-Lab, Nanyang Technologic

Fangzhou Hong 96 Jan 03, 2023
Python Implementation of Chess Playing AI with variable difficulty

Chess AI with variable difficulty level implemented using the MiniMax AB-Pruning Algorithm

Ali Imran 7 Feb 20, 2022
Open source implementation of AceNAS: Learning to Rank Ace Neural Architectures with Weak Supervision of Weight Sharing

AceNAS This repo is the experiment code of AceNAS, and is not considered as an official release. We are working on integrating AceNAS as a built-in st

Yuge Zhang 6 Sep 07, 2022
Instance-based label smoothing for improving deep neural networks generalization and calibration

Instance-based Label Smoothing for Neural Networks Pytorch Implementation of the algorithm. This repository includes a new proposed method for instanc

Mohamed Maher 1 Aug 13, 2022
Train an RL agent to execute natural language instructions in a 3D Environment (PyTorch)

Gated-Attention Architectures for Task-Oriented Language Grounding This is a PyTorch implementation of the AAAI-18 paper: Gated-Attention Architecture

Devendra Chaplot 234 Nov 05, 2022
The Malware Open-source Threat Intelligence Family dataset contains 3,095 disarmed PE malware samples from 454 families

MOTIF Dataset The Malware Open-source Threat Intelligence Family (MOTIF) dataset contains 3,095 disarmed PE malware samples from 454 families, labeled

Booz Allen Hamilton 112 Dec 13, 2022
Mememoji - A facial expression classification system that recognizes 6 basic emotions: happy, sad, surprise, fear, anger and neutral.

a project built with deep convolutional neural network and ❤️ Table of Contents Motivation The Database The Model 3.1 Input Layer 3.2 Convolutional La

Jostine Ho 761 Dec 05, 2022
Using OpenAI's CLIP to upscale and enhance images

CLIP Upscaler and Enhancer Using OpenAI's CLIP to upscale and enhance images Based on nshepperd's JAX CLIP Guided Diffusion v2.4 Sample Results Viewpo

Tripp Lyons 5 Jun 14, 2022
Source code for our paper "Molecular Mechanics-Driven Graph Neural Network with Multiplex Graph for Molecular Structures"

Molecular Mechanics-Driven Graph Neural Network with Multiplex Graph for Molecular Structures Code for the Multiplex Molecular Graph Neural Network (M

shzhang 59 Dec 10, 2022