Repository for fine-tuning Transformers 🤗 based seq2seq speech models in JAX/Flax.

Overview

Seq2Seq Speech in JAX

A JAX/Flax repository for combining a pre-trained speech encoder model (e.g. Wav2Vec2, HuBERT, WavLM) with a pre-trained text decoder model (e.g. GPT2, Bart) to yield a Speech Sequence-to-Sequence (Seq2Seq) model for automatic speech recognition.

The script run_flax_speech_recognition_seq2seq.py can be used to fine-tune a Speech Seq2Seq model on one of the official speech recognition datasets or on a custom dataset. It uses the JAX pmap operator to provide data parallelism across GPU/TPU devices.
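As a rough illustration of how pmap spreads a training step across devices, here is a minimal sketch of the common pattern: replicate the parameters, give each device its own slice of the batch, and average the gradients. It is not the training step from run_flax_speech_recognition_seq2seq.py; the linear model, loss_fn, train_step and the learning rate of 0.1 are hypothetical stand-ins chosen to keep the example self-contained.

```python
import functools

import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()


def loss_fn(w, x, y):
    # Toy linear regression loss standing in for the real Seq2Seq loss.
    pred = x @ w
    return jnp.mean((pred - y) ** 2)


@functools.partial(jax.pmap, axis_name="batch")
def train_step(w, x, y):
    # Each device computes the loss and gradient on its own shard of the batch.
    loss, grad = jax.value_and_grad(loss_fn)(w, x, y)
    # Average across devices so every replica applies the same update.
    grad = jax.lax.pmean(grad, axis_name="batch")
    loss = jax.lax.pmean(loss, axis_name="batch")
    return w - 0.1 * grad, loss


# Replicate the parameters: one copy per device along the leading axis.
w = jnp.zeros((3,))
w_repl = jnp.broadcast_to(w, (n_devices,) + w.shape)

# Shard the data: the leading axis indexes devices, the next axis the per-device batch.
x = jnp.ones((n_devices, 2, 3))
y = jnp.ones((n_devices, 2))

w_repl, loss = train_step(w_repl, x, y)
print(w_repl.shape, loss.shape)
```

The leading axis of every input to the pmapped function must equal the number of local devices, which is why the batch is reshaped to (n_devices, per_device_batch, ...) before the call.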

The modelling files are based very heavily on those from Hugging Face Transformers 🤗. This is a standalone repository to enable rapid prototyping and involvement with the community. The final modelling files and training script will be merged into Transformers 🤗 to be used with the rest of the open-source library. The final system weights will be made publicly available at huggingface.co 🚀

Figure 1: Speech-encoder text-decoder style Seq2Seq model.

Example Usage

To instantiate a Wav2Vec2-2-Bart model with the FlaxSpeechEncoderDecoderModel framework, run the following Python script inside the cloned repo:

from transformers import AutoFeatureExtractor, AutoTokenizer
from models.modeling_flax_speech_encoder_decoder import FlaxSpeechEncoderDecoderModel
import numpy as np

# checkpoints to leverage
encoder_id = "facebook/wav2vec2-large-lv60"
decoder_id = "facebook/bart-large"

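# combine the pre-trained encoder and decoder into a single model;
# decoder_from_pt=True loads the decoder from its PyTorch weights
model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder_id, decoder_id, encoder_add_adapter=True, decoder_from_pt=True)

# reuse the decoder's special tokens for training and generation
model.config.decoder_start_token_id = model.config.decoder.bos_token_id
model.config.pad_token_id = model.config.decoder.pad_token_id
model.config.eos_token_id = model.config.decoder.eos_token_id
# disable the decoder's key/value cache
model.config.use_cache = False
model.config.processor_class = "Wav2Vec2Processor"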

# check if generation works
out = model.generate(np.ones((1, 2000)))

model.save_pretrained("./")

feature_extractor = AutoFeatureExtractor.from_pretrained(encoder_id)
feature_extractor.save_pretrained("./")
tokenizer = AutoTokenizer.from_pretrained(decoder_id)
tokenizer.save_pretrained("./")

To train the model on LibriSpeech ASR in default precision, run the bash script below:

#!/usr/bin/env bash
python run_flax_speech_recognition_seq2seq.py \
        --dataset_name="librispeech_asr" \
        --model_name_or_path="./" \
        --dataset_config_name="clean" \
        --train_split_name="train.100" \
        --eval_split_name="validation" \
        --output_dir="./" \
        --preprocessing_num_workers="16" \
        --length_column_name="input_length" \
        --overwrite_output_dir \
        --num_train_epochs="5" \
        --per_device_train_batch_size="2" \
        --per_device_eval_batch_size="2" \
        --gradient_accumulation_steps="1" \
        --logging_steps="25" \
        --max_duration_in_seconds="15" \
        --max_target_length="128" \
        --generation_max_length="40" \
        --generation_num_beams="1" \
        --learning_rate="1e-4" \
        --warmup_steps="500" \
        --text_column_name="text" \
        --save_total_limit="1" \
        --freeze_feature_encoder \
        --predict_with_generate \
        --do_lower_case \
        --do_eval \
        --do_train
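
The command above sets a per-device batch size, so the number of examples behind each optimiser update depends on how many devices are available. A small sketch of that arithmetic, assuming the script follows the usual convention of multiplying the per-device batch size by the device count and the gradient accumulation steps; the helper effective_batch_size and the device count of 8 are hypothetical and chosen only for illustration:

```python
def effective_batch_size(
    per_device_batch_size: int,
    num_devices: int,
    gradient_accumulation_steps: int = 1,
) -> int:
    """Number of examples contributing to one optimiser update."""
    return per_device_batch_size * num_devices * gradient_accumulation_steps


# Values from the command above, with an assumed 8 devices.
batch = effective_batch_size(
    per_device_batch_size=2,
    num_devices=8,
    gradient_accumulation_steps=1,
)
print(batch)  # 16
```

On a real machine, the device count can be read with jax.device_count() instead of being hard-coded.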
Owner
Sanchit Gandhi
Open-Source Speech @huggingface