使用pytorch+transformers复现了SimCSE论文中的有监督训练和无监督训练方法

Related tags

Text Data & NLPSimCSE
Overview

SimCSE复现

项目描述

SimCSE是一种简单但是很巧妙的NLP对比学习方法,创新性地引入Dropout的方式,对样本添加噪声,从而达到对正样本增强的目的。 该框架的训练目的为:对于batch中的每个样本,拉近其与正样本之间的距离,拉远其与负样本之间的距离,使得模型能够在大规模无监督语料(也可以使用有监督的语料)中学习到文本相似关系。 详见论文:Simple Contrastive Learning of Sentence EmbeddingsSimCSE官方代码仓库

本项目使用pytorch+transformers复现了SimCSE论文中的有监督训练和无监督训练方法,并且在STS-B数据集上进行消融实验,评价指标为Spearman相关系数,预训练模型为Bert-base-uncased, 验证了SimCSE的有效性。在STS-B数据集上,有监督训练和无监督训练的复现效果如下表。

在无监督训练中,dropout=0.1时,复现效果比原文略差,但也比较接近。当dropout=0.2时,复现效果比原文略高。 ** 但在有监督训练中,不知是否由于batch size过小(原论文使用512),复现效果与论文的效果相差较远,后续会进行排查。 **

训练方法 learning rate batch size dropout Spearman’s correlation
原论文 无监督 3e-5 64 0.1 0.763
复现 无监督 3e-5 64 0.2 0.771
复现 无监督 3e-5 64 0.1 0.748
原论文 有监督 5e-5 512 0.1 0.816
复现 有监督 5e-5 64 0.1 0.764

运行环境

python==3.6、transformers==3.1.0、torch==1.6.0

项目结构

  • data:存放训练数据
    • stsbenchmark:STS-B数据集
      • sts-dev.csv:STS-B验证集
      • sts-test.csv:STS-B验测试集
    • nli_for_simcse.csv:数量275601为的NLI数据集
    • wiki1m_for_simcse.txt:维基百科上获取的100w的文本
  • output:输出目录
  • pretrain_model:预训练模型存放位置
  • script:脚本存放位置。
  • dataset.py
  • model.py:模型代码,包含有监督和无监督损失函数的计算方式
  • train.py:训练代码

使用方法

Quick Start

下载训练数据:

bash script/download_nli.sh
bash script/download_wiki.sh

无监督训练,运行脚本

bash script/run_unsup_train.sh

有监督训练,运行脚本

bash script/run_sup_train.sh

实验

无监督训练

从前四条实验数据中可以看到,较大的batch size在一定程度上可以增加模型的泛化性。

dropout为0.2的时候,训练效果比0.1与0.3更好,有可能dropout=0.1加入的噪声过小,而dropout=0.3加入的噪声过大,增强得到的样本与原始样本差异较大。

learning rate batch size dropout 在哪一步得到best checkpoint 验证集上的得分 测试集上的得分
3e-5 256 0.1 6000 0.800 0.761
3e-5 128 0.1 4200 0.799 0.747
3e-5 64 0.1 10900 0.803 0.748
3e-5 32 0.1 21300 0.787 0.714
3e-5 64 0.2 11200 0.811 0.771
3e-5 64 0.3 6300 0.781 0.745
1e-5 64 0.1 16400 0.798 0.751

有监督训练

有监督实验的复现结果未达到预期,超参数相同时,在验证集上的得分略高于无监督,但是在测试集上,得分基本没有差异。增大有监督训练的学习率,有监督的训练的得分略高于无监督训练, 但还是与论文声称的0.816相差较远,原论文使用512的batch size, 不知是否由于batch size的设置有关,后续会对有监督的训练代码进一步排查。

不过从训练曲线可以看到,有监督训练的收敛速度明显快于无监督训练,这也符合我们的认知。

训练方法 learning rate batch size dropout 在哪一步得到best checkpoint 验证集上的得分 测试集上的得分
无监督 3e-5 64 0.1 10900 0.803 0.748
有监督 3e-5 64 0.1 200 0.810 0.748
有监督 5e-5 64 0.1 2300 0.809 0.764
有监督 3e-5 32 0.1 200 0.808 0.743
有监督 5e-5 32 0.1 200 0.806 0.746

无监督训练过程中,验证集得分的变化曲线: avatar

有监督训练过程中,验证集得分的变化曲线: avatar

REFERENCE

TODO

  • 排查有监督学习的效果不符合预期的原因
Pattern Matching in Python

Pattern Matching finalmente chega no Python 3.10. E daí? "Pattern matching", ou "correspondência de padrões" como é conhecido no Brasil. Algumas pesso

Fabricio Werneck 6 Feb 16, 2022
SimCTG - A Contrastive Framework for Neural Text Generation

A Contrastive Framework for Neural Text Generation Authors: Yixuan Su, Tian Lan,

Yixuan Su 345 Jan 03, 2023
Toward Model Interpretability in Medical NLP

Toward Model Interpretability in Medical NLP LING380: Topics in Computational Linguistics Final Project James Cross ( 1 Mar 04, 2022

TPlinker for NER 中文/英文命名实体识别

本项目是参考 TPLinker 中HandshakingTagging思想,将TPLinker由原来的关系抽取(RE)模型修改为命名实体识别(NER)模型。

GodK 113 Dec 28, 2022
An open source framework for seq2seq models in PyTorch.

pytorch-seq2seq Documentation This is a framework for sequence-to-sequence (seq2seq) models implemented in PyTorch. The framework has modularized and

International Business Machines 1.4k Jan 02, 2023
The (extremely) naive sentiment classification function based on NBSVM trained on wisesight_sentiment

thai_sentiment The naive sentiment classification function based on NBSVM trained on wisesight_sentiment วิธีติดตั้ง pip install thai_sentiment==0.1.3

Charin 7 Dec 08, 2022
All the code I wrote for Overwatch-related projects that I still own the rights to.

overwatch_shit.zip This is (eventually) going to contain all the software I wrote during my five-year imprisonment stay playing Overwatch. I'll be add

zkxjzmswkwl 2 Dec 31, 2021
Simple NLP based project without any use of AI

Simple NLP based project without any use of AI

Shripad Rao 1 Apr 26, 2022
NLP: SLU tagging

NLP: SLU tagging

北海若 3 Jan 14, 2022
String Gen + Word Checker

Creates random strings and checks if any of them are a real words. Mostly a waste of time ngl but it is cool to see it work and the fact that it can generate a real random word within10sec

1 Jan 06, 2022
An official repository for tutorials of Probabilistic Modelling and Reasoning (2021/2022) - a University of Edinburgh master's course.

PMR computer tutorials on HMMs (2021-2022) This is a repository for computer tutorials of Probabilistic Modelling and Reasoning (2021/2022) - a Univer

Vaidotas Šimkus 10 Dec 06, 2022
Code for our paper "Transfer Learning for Sequence Generation: from Single-source to Multi-source" in ACL 2021.

TRICE: a task-agnostic transferring framework for multi-source sequence generation This is the source code of our work Transfer Learning for Sequence

THUNLP-MT 9 Jun 27, 2022
Implementation of the Hybrid Perception Block and Dual-Pruned Self-Attention block from the ITTR paper for Image to Image Translation using Transformers

ITTR - Pytorch Implementation of the Hybrid Perception Block (HPB) and Dual-Pruned Self-Attention (DPSA) block from the ITTR paper for Image to Image

Phil Wang 17 Dec 23, 2022
Simple, Pythonic, text processing--Sentiment analysis, part-of-speech tagging, noun phrase extraction, translation, and more.

TextBlob: Simplified Text Processing Homepage: https://textblob.readthedocs.io/ TextBlob is a Python (2 and 3) library for processing textual data. It

Steven Loria 8.4k Dec 26, 2022
BERTopic is a topic modeling technique that leverages 🤗 transformers and c-TF-IDF to create dense clusters allowing for easily interpretable topics whilst keeping important words in the topic descriptions

BERTopic BERTopic is a topic modeling technique that leverages 🤗 transformers and c-TF-IDF to create dense clusters allowing for easily interpretable

Maarten Grootendorst 3.6k Jan 07, 2023
Simple Text-Generator with OpenAI gpt-2 Pytorch Implementation

GPT2-Pytorch with Text-Generator Better Language Models and Their Implications Our model, called GPT-2 (a successor to GPT), was trained simply to pre

Tae-Hwan Jung 775 Jan 08, 2023
test

Lidar-data-decode In this project, you can decode your lidar data frame(pcap file) and make your own datasets(test dataset) in Windows without any hug

46 Dec 05, 2022
The tool to make NLP datasets ready to use

chazutsu photo from Kaikado, traditional Japanese chazutsu maker chazutsu is the dataset downloader for NLP. import chazutsu r = chazutsu.data

chakki 243 Dec 29, 2022
Part of Speech Tagging using Hidden Markov Model (HMM) POS Tagger and Brill Tagger

Part of Speech Tagging using Hidden Markov Model (HMM) POS Tagger and Brill Tagger In this project, our aim is to tune, compare, and contrast the perf

Chirag Daryani 0 Dec 25, 2021
Official source for spanish Language Models and resources made @ BSC-TEMU within the "Plan de las Tecnologías del Lenguaje" (Plan-TL).

Spanish Language Models 💃🏻 A repository part of the MarIA project. Corpora 📃 Corpora Number of documents Number of tokens Size (GB) BNE 201,080,084

Plan de Tecnologías del Lenguaje - Gobierno de España 203 Dec 20, 2022