The Codebase for Causal Distillation for Language Models.

Overview

Python 3.7 | License: CC BY-NC

Causal Distillation for Language Models

Zhengxuan Wu*, Atticus Geiger*, Josh Rozner, Elisa Kreiss, Hanson Lu, Thomas Icard, Christopher Potts, Noah D. Goodman

This is an implementation of our preprint Causal Distillation for Language Models. The standard approach to distillation trains a student model against two objectives: a task-specific objective (e.g., language modeling) and an imitation objective that encourages the hidden states of the student model to be similar to those of the larger teacher model. In this paper, we show that it is beneficial to augment distillation with a third objective that encourages the student to imitate the causal computation process of the teacher through interchange intervention training (IIT).
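To make the third objective concrete, here is a minimal, self-contained sketch of an interchange intervention on toy models. The TinyEncoder class and the choice to swap an entire layer's activations are illustrative assumptions on our part; the repository's actual implementation lives in causal_train.py and operates on aligned subsets of hidden units.

import torch
import torch.nn as nn

class TinyEncoder(nn.Module):
    """Toy 2-layer encoder standing in for a teacher or student LM."""
    def __init__(self, dim=16, vocab=100):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.layer1 = nn.Linear(dim, dim)
        self.layer2 = nn.Linear(dim, dim)
        self.head = nn.Linear(dim, vocab)

    def forward(self, ids, swap_hidden=None):
        h = torch.relu(self.layer1(self.embed(ids)))
        cached = h
        if swap_hidden is not None:
            # The interchange intervention: overwrite this layer's
            # activations with activations cached from a "source" input.
            # (The real method swaps only an aligned subset of units.)
            h = swap_hidden
        return self.head(torch.relu(self.layer2(h))), cached

teacher, student = TinyEncoder(), TinyEncoder()
base = torch.randint(0, 100, (1, 8))    # "base" input
source = torch.randint(0, 100, (1, 8))  # "source" input

with torch.no_grad():                    # the teacher stays frozen
    _, t_src = teacher(source)
    t_logits, _ = teacher(base, swap_hidden=t_src)

_, s_src = student(source)
s_logits, _ = student(base, swap_hidden=s_src)

# Causal objective: the student's counterfactual predictions should
# match the teacher's counterfactual predictions.
loss_causal = nn.KLDivLoss(reduction="batchmean")(
    s_logits.log_softmax(-1), t_logits.softmax(-1)
)
loss_causal.backward()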

We fork our main codebase from the Huggingface Distillation Interface.

Release Notes

12/02/2021 Our paper on Interchange Intervention Training (IIT) is released! It provides a more formal definition of the method.
12/06/2021 Released the causal distillation codebase with the preprint.
12/06/2021 Released evaluation results on distilled tiny-BERT (3 layers) with the Wiki-Text 103M dataset.
⬜️ Release evaluation results on causal-distilled tiny-BERT (3 layers) with the Wiki-Text 103M + BookCorpus dataset.
⬜️ Release evaluation results on causal-distilled BERT (6 layers) with the Wiki-Text 103M + BookCorpus dataset.
⬜️ Release more ablation studies.
⬜️ Release causal-distilled tiny-BERT (3 layers) model files.
⬜️ Release causal-distilled BERT (6 layers) model files.

If you experience any issues or have suggestions, please contact me either through the issues page or at [email protected].

Benchmark Results

Here are the results on the dev sets of GLUE:

Model                  Average-score¹  CoLA  MNLI  MRPC  QNLI  QQP   RTE   SST-2  STS-B  WNLI
DistilBERT (3 layers)  67.8            22.8  71.6  78.2  82.1  84.3  55.4  86.5   56.7   24.2
CausalBERT (3 layers)  69.7            25.0  72.9  78.6  83.1  84.9  55.4  86.9   66.5   21.5

¹ Average score computed without WNLI.

Main Contents

Citation

If you use this repository, please cite the following two papers: the paper introducing interchange intervention training, and the paper describing our distillation method.

  @article{geiger-etal-2021-iit,
        title={Inducing Causal Structure for Interpretable Neural Networks}, 
        author={Geiger, Atticus and Wu, Zhengxuan and Lu, Hanson and Rozner, Josh and Kreiss, Elisa and Icard, Thomas and Goodman, Noah D. and Potts, Christopher},
        year={2021},
        eprint={2112.00826},
        archivePrefix={arXiv},
        primaryClass={cs.LG}
  }

  @article{wu-etal-2021-distill,
        title={Causal Distillation for Language Models}, 
        author={Wu, Zhengxuan and Geiger, Atticus and Rozner, Josh and Kreiss, Elisa and Lu, Hanson and Icard, Thomas and Potts, Christopher and Goodman, Noah D.},
        year={2021},
        eprint={2112.02505},
        archivePrefix={arXiv},
        primaryClass={cs.CL}
  }

Requirements

  • Python 3.6 and 3.7 are supported.
  • PyTorch version: 1.9.0
  • Transformers version: 4.11.3
  • Datasets version: 1.8.0
  • We performed our experiments on a Titan V GPU. We assume 12GB of GPU memory; more memory can expedite training.
  • Since we build our codebase off the Huggingface Distillation Interface, please review their documentation for requirements.

Dataset

Following the Huggingface Distillation Interface, we need to pre-process the datasets before distillation; you can refer to their repo for details. We adapt their pre-processing scripts with a few improvements. For example, we can now binarize datasets directly from the Huggingface Dataset Hub.
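For intuition, the binarized output is (as we understand the interface) essentially a pickled list of per-line token-id sequences. A minimal sketch of the idea, with the real logic, chunking, and special-token handling living in scripts/binarized_data.py:

# Illustrative sketch of binarization: tokenize each line of text into
# ids and pickle the resulting list. The output format here is an
# assumption; scripts/binarized_data.py is the real implementation.
import pickle
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
lines = ["hello world .", "causal distillation for language models ."]
data = [tokenizer.encode(line) for line in lines]  # token ids per line
with open("binarized_text.train.bert-base-uncased.pickle", "wb") as f:
    pickle.dump(data, f)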

# preprocessing from disk
python scripts/binarized_data.py \
--file_path ../../bert-mid-tuning/data-files/wikitext-15M \
--split train \
--field_name text \
--max_parsing_example 1000 \
--tokenizer_type bert \
--tokenizer_name bert-base-uncased \
--dump_file ./data/binarized_text

# preprocessing from huggingface.
python scripts/binarized_data.py \
--dataset_name bookcorpus \
--split train \
--field_name text \
--tokenizer_type bert \
--tokenizer_name bert-base-uncased \
--dump_file bookcorpus-dataset/binarized_text \
--cache_dir ./distill_cache/

python scripts/binarized_data.py \
--dataset_name wikitext \
--split train \
--field_name text \
--tokenizer_type bert \
--tokenizer_name bert-base-uncased \
--dump_file wikitext-dataset/binarized_text \
--cache_dir ./distill_cache/

python scripts/binarized_data.py \
--dataset_name wikitext+bookcorpus \
--split train \
--field_name text \
--tokenizer_type bert \
--tokenizer_name bert-base-uncased \
--dump_file wikitext+bookcorpus-dataset/binarized_text \
--cache_dir ./distill_cache/

# helper scripts to combine two binarized data files
python scripts/data_combinator.py \
--file_path_left ./bookcorpus-dataset/binarized_text.train.bert-base-uncased.pickle \
--file_path_right ./wikitext-dataset/binarized_text.train.bert-base-uncased.pickle \
--split train \
--tokenizer_name bert-base-uncased \
--dump_file wikitext+bookcorpus-dataset/binarized_text

# multiprocessing preprocessor.
python scripts/binarized_data.py \
--dataset_name bookcorpus \
--split train \
--field_name text \
--tokenizer_type bert \
--tokenizer_name bert-base-uncased \
--dump_file bookcorpus-dataset/binarized_text \
--cache_dir ./distill_cache/ \
--fast_process \
--preprocessing_num_workers 48

After you get the datasets ready, you need to generate token counts as well.

python scripts/token_counts.py \
--data_file data/binarized_text.train.bert-base-uncased.pickle \
--token_counts_dump data/binarized_text.train.token_counts.bert-base-uncased.pickle \
--vocab_size 30522
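The token counts are used to smooth the probability with which tokens are masked during MLM distillation, so that rarer tokens are masked more often. A hedged sketch of how such counts plausibly translate into masking probabilities; the 0.7 smoothing exponent follows the Huggingface distillation code's default, and the exact details in the real trainer may differ:

# Illustrative sketch: turning token counts into MLM masking probabilities.
import pickle
import numpy as np

with open("data/binarized_text.train.token_counts.bert-base-uncased.pickle", "rb") as f:
    counts = pickle.load(f)  # one count per vocabulary id

# Rare tokens get relatively higher masking probability; the exponent
# 0.7 is the Huggingface distillation default (assumption on our part).
probs = np.maximum(np.array(counts, dtype=np.float64), 1) ** -0.7
probs /= probs.sum()  # normalized sampling distribution over the vocab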

Distillation

Before training, we recommend initializing your student model with weights extracted from the teacher model.

python scripts/extract_distilbert.py \
--model_type bert \
--model_name bert-base-uncased \
--dump_checkpoint ./distillation_checkpoints/bert-base-uncased_num_layer_3.pth \
--num_layers 3
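For intuition, extraction roughly amounts to copying the teacher's embeddings and a spread of its transformer layers into the smaller student. A hedged sketch of the idea; the specific layers copied here ([0, 5, 11]) and the parameter naming are illustrative guesses, and scripts/extract_distilbert.py defines the actual mapping:

# Hedged sketch of initializing a 3-layer student from a 12-layer teacher.
import torch
from transformers import BertModel

teacher = BertModel.from_pretrained("bert-base-uncased")
state = teacher.state_dict()

student_state = {}
for name, tensor in state.items():
    if name.startswith("embeddings."):
        student_state[name] = tensor  # keep all embedding weights
    # Copy an evenly spread subset of teacher layers into the student
    # ([0, 5, 11] is an illustrative choice, not the repo's mapping).
    for student_idx, teacher_idx in enumerate([0, 5, 11]):
        prefix = f"encoder.layer.{teacher_idx}."
        if name.startswith(prefix):
            new_name = f"encoder.layer.{student_idx}." + name[len(prefix):]
            student_state[new_name] = tensor

torch.save(student_state, "./distillation_checkpoints/bert-base-uncased_num_layer_3.pth")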

Now, here are examples of distilling with and without our causal distillation objective: the first command enables it (--alpha_causal 0.25), and the second disables it (--alpha_causal 0.00).

CUDA_VISIBLE_DEVICES=9,4 python causal_train.py \
--force \
--n_gpu 2 \
--is_wandb \
--log_interval 10 \
--student_type distilbert \
--student_config ./training_configs/distilbert-base-uncased-small.json \
--student_pretrained_weights ./distillation_checkpoints/bert-base-uncased_num_layer_3.pth \
--teacher_type bert \
--teacher_name bert-base-uncased \
--neuron_mapping ./training_configs/single_middle.nm \
--mlm --alpha_ce 0.25 --alpha_mlm 0.25 --alpha_cos 0.25 --alpha_clm 0.0 --alpha_causal 0.25 \
--freeze_pos_embs \
--dump_path ./results/ \
--data_file ./wikitext-15M/binarized_text.train.bert-base-uncased.pickle \
--token_counts ./wikitext-15M/binarized_text.train.token_counts.bert-base-uncased.pickle \
--seed 42 \
--gradient_accumulation_steps 50 \
--n_epoch 3 \
--batch_size 5

CUDA_VISIBLE_DEVICES=0,1,2,3 python causal_train.py \
--force \
--n_gpu 4 \
--is_wandb \
--log_interval 10 \
--student_type distilbert \
--student_config ./training_configs/distilbert-base-uncased-small.json \
--student_pretrained_weights ./distillation_checkpoints/bert-base-uncased_num_layer_3.pth \
--teacher_type bert \
--teacher_name bert-base-uncased \
--neuron_mapping ./training_configs/single_middle.nm \
--mlm --alpha_ce 0.33 --alpha_mlm 0.33 --alpha_cos 0.33 --alpha_clm 0.0 --alpha_causal 0.00 \
--freeze_pos_embs \
--dump_path ./results/ \
--data_file ./wikitext-15M/binarized_text.train.bert-base-uncased.pickle \
--token_counts ./wikitext-15M/binarized_text.train.token_counts.bert-base-uncased.pickle \
--seed 42 \
--gradient_accumulation_steps 124 \
--n_epoch 6 \
--batch_size 4

Note that you can turn our causal distillation objective on or off simply by setting the --alpha_causal argument, as in the two commands above.
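For clarity, the --alpha_* flags plausibly act as linear weights on the individual objectives. A toy illustration, where the variable names are ours and the loss values are dummies:

# Toy illustration of how the --alpha_* weights combine the objectives;
# setting alpha_causal to 0.0 drops the causal term entirely.
alpha_ce, alpha_mlm, alpha_cos, alpha_clm, alpha_causal = 0.25, 0.25, 0.25, 0.0, 0.25
loss_ce, loss_mlm, loss_cos, loss_clm, loss_causal = 2.1, 1.8, 0.4, 0.0, 0.9  # dummy values
total_loss = (alpha_ce * loss_ce + alpha_mlm * loss_mlm + alpha_cos * loss_cos
              + alpha_clm * loss_clm + alpha_causal * loss_causal)
print(f"total distillation loss: {total_loss:.3f}")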

Evaluation

After you get your distilled models, you need to fine-tune and evaluate them on downstream tasks. We provide all the scripts you need to run.

MLM Evaluation

CUDA_VISIBLE_DEVICES=5 python run_mlm.py \
--model_name_or_path ./results/s_distilbert_t_bert_data_wikitext-15M_seed_42_mlm_True_ce_0.25_mlm_0.25_cos_0.25_causal_0.25_nm_single_multilayer/ \
--dataset_dir ../../bert-mid-tuning/data-files/wikitext-15M/ \
--tokenizer_name bert-base-uncased \
--do_eval \
--output_dir /tmp/test-mlm \
--cache_dir ./distill_cache/

GLUE Evaluation

CUDA_VISIBLE_DEVICES=5,7,8,9 python run_glue.py \
--model_name_or_path ./results/s_distilbert_t_bert_data_wikitext-dataset_seed_42_mlm_True_ce_0.33_mlm_0.33_cos_0.33_causal_0.0_nm_single_middle/ \
--tokenizer_name bert-base-uncased \
--task_name sst2 \
--do_train \
--do_eval \
--max_seq_length 128 \
--per_device_train_batch_size 32 \
--learning_rate 2e-5 \
--num_train_epochs 3 \
--output_dir ./results/ \
--save_total_limit 1 \
--cache_dir ./distill_cache/

CoNLL Evaluation

CUDA_VISIBLE_DEVICES=2,3,7,8 python run_ner.py \
--model_name_or_path ./results/s_distilbert_t_bert_data_wikitext-dataset_seed_42_mlm_True_ce_0.33_mlm_0.33_cos_0.33_causal_0.0_nm_single_middle_crossway_False/ \
--tokenizer_name bert-base-uncased \
--dataset_name conll2003 \
--do_train \
--do_eval \
--output_dir ./ner_results/ \
--save_total_limit 1 \
--cache_dir ./distill_cache/

SQuAD Evaluation

CUDA_VISIBLE_DEVICES=2,3,7,8 python run_qa.py \
--model_name_or_path ./results/s_distilbert_t_bert_data_wikitext-dataset_seed_42_mlm_True_ce_0.33_mlm_0.33_cos_0.33_causal_0.0_nm_single_middle_crossway_False/ \
--tokenizer_name bert-base-uncased \
--dataset_name squad \
--do_train \
--do_eval \
--per_device_train_batch_size 12 \
--learning_rate 3e-5 \
--num_train_epochs 2 \
--max_seq_length 384 \
--doc_stride 128 \
--save_total_limit 1 \
--output_dir ./qa_results/