BYOL for Audio: Self-Supervised Learning for General-Purpose Audio Representation

Overview

This is a demo implementation of BYOL for Audio (BYOL-A), a self-supervised learning method for general-purpose audio representation. It includes:

  • Training code that can train models with arbitrary audio files.
  • Evaluation code that can evaluate trained models with downstream tasks.
  • Pretrained weights.

If you find BYOL-A useful in your research, please use the following BibTeX entry for citation.

@inproceedings{niizumi2021byol-a,
      title={BYOL for Audio: Self-Supervised Learning for General-Purpose Audio Representation}, 
      author={Daisuke Niizumi and Daiki Takeuchi and Yasunori Ohishi and Noboru Harada and Kunio Kashino},
      booktitle = {2021 International Joint Conference on Neural Networks, {IJCNN} 2021},
      year={2021},
      eprint={2103.06695},
      archivePrefix={arXiv},
      primaryClass={eess.AS}
}

Getting Started

  1. Download the external source files and apply a patch. Our implementation uses the following:

    curl -O https://raw.githubusercontent.com/lucidrains/byol-pytorch/2aa84ee18fafecaf35637da4657f92619e83876d/byol_pytorch/byol_pytorch.py
    patch < byol_a/byol_pytorch.diff
    mv byol_pytorch.py byol_a
    curl -O https://raw.githubusercontent.com/daisukelab/general-learning/7b31d31637d73e1a74aec3930793bd5175b64126/MLP/torch_mlp_clf.py
    mv torch_mlp_clf.py utils
  2. Install PyTorch 1.7.1, torchaudio, and the other dependencies listed in requirements.txt.

Evaluating BYOL-A Representations

Downstream Task Evaluation

The following steps perform a downstream task evaluation in linear-probe fashion. This example uses SPCV2 (Speech Commands Dataset v2).

  1. Preprocess the metadata (.csv file) and audio files; the processed files will be stored under the folder work.

    # usage: python -m utils.preprocess_ds <downstream task> <path to its dataset>
    python -m utils.preprocess_ds spcv2 /path/to/speech_commands_v0.02
  2. Run the evaluation. This converts all .wav audio to representation embeddings first, trains a linear-layer network, then reports accuracy; a conceptual sketch follows after the usage example below.

    python evaluate.py pretrained_weights/AudioNTT2020-BYOLA-64x96d2048.pth spcv2

You can also run an evaluation multiple times and average the results. The following evaluates on UrbanSound8K with a unit audio duration of 4.0 seconds, repeated 10 times.

# usage: python evaluate.py <your weight> <downstream task> <unit duration sec.> <# of iteration>
python evaluate.py pretrained_weights/AudioNTT2020-BYOLA-64x96d2048.pth us8k 4.0 10
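
Conceptually, the linear probe freezes the encoder, turns each clip into an embedding, and trains only a lightweight classifier on top. Below is a minimal sketch of that idea, using scikit-learn's LogisticRegression in place of the repository's torch_mlp_clf, with hypothetical variable names (model, X_train, etc.); evaluate.py is the reference implementation.

    import torch
    from sklearn.linear_model import LogisticRegression
    from sklearn.preprocessing import StandardScaler

    def extract_embeddings(model, spectrograms, device):
        # Run the frozen encoder over precomputed log-mel spectrograms.
        model.eval().to(device)
        with torch.no_grad():
            embs = [model(x.unsqueeze(0).to(device)).cpu() for x in spectrograms]
        return torch.cat(embs).numpy()

    # X_train / X_val: lists of (1, F, T) log-mel tensors; y_train / y_val: labels.
    scaler = StandardScaler()
    X_tr = scaler.fit_transform(extract_embeddings(model, X_train, device))
    X_va = scaler.transform(extract_embeddings(model, X_val, device))  # scale val too
    clf = LogisticRegression(max_iter=1000).fit(X_tr, y_train)
    print('val acc:', clf.score(X_va, y_val))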

Evaluating Representations In Your Tasks

This example calculates a feature vector for an audio sample.

from byol_a.common import *
from byol_a.augmentations import PrecomputedNorm
from byol_a.models import AudioNTT2020


device = torch.device('cuda')
cfg = load_yaml_config('config.yaml')
print(cfg)

# Mean and standard deviation of the log-mel spectrogram of input audio samples, pre-computed.
# See calc_norm_stats in evaluate.py for your reference.
stats = [-5.4919195,  5.0389895]

# Preprocessor and normalizer.
to_melspec = torchaudio.transforms.MelSpectrogram(
    sample_rate=cfg.sample_rate,
    n_fft=cfg.n_fft,
    win_length=cfg.win_length,
    hop_length=cfg.hop_length,
    n_mels=cfg.n_mels,
    f_min=cfg.f_min,
    f_max=cfg.f_max,
)
normalizer = PrecomputedNorm(stats)

# Load pretrained weights.
model = AudioNTT2020(d=cfg.feature_d)
model.load_weight('pretrained_weights/AudioNTT2020-BYOLA-64x96d2048.pth', device)

# Load your audio file.
wav, sr = torchaudio.load('work/16k/spcv2/one/00176480_nohash_0.wav') # a sample from SPCV2 for now
assert sr == cfg.sample_rate, "Let's convert the audio sampling rate in advance, or do it here online."

# Convert to a log-mel spectrogram, then normalize.
lms = normalizer((to_melspec(wav) + torch.finfo(torch.float).eps).log())

# Now, convert the audio to the representation.
features = model(lms.unsqueeze(0))
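
The stats values above are dataset-dependent. To compute them for your own audio, here is a minimal sketch along the lines of calc_norm_stats in evaluate.py (a hypothetical helper, not the exact repository code):

    import numpy as np
    import torch
    import torchaudio

    def compute_norm_stats(wav_files, to_melspec):
        # Collect log-mel values over the dataset, then take mean and std.
        eps = torch.finfo(torch.float).eps
        values = []
        for f in wav_files:
            wav, sr = torchaudio.load(f)
            lms = (to_melspec(wav) + eps).log()
            values.append(lms.flatten().numpy())
        values = np.concatenate(values)
        return [float(values.mean()), float(values.std())]

    # stats = compute_norm_stats(list_of_wav_paths, to_melspec)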

Training From Scratch

You can also train models. The following is an example of training on FSD50K.

  1. Convert all samples to 16 kHz. This converts all FSD50K files into the folder work/16k/fsd50k while preserving the folder structure.

    python -m utils.convert_wav /path/to/fsd50k work/16k/fsd50k
  2. Start training. This example trains on all development-set audio samples from FSD50K.

    python train.py work/16k/fsd50k/FSD50K.dev_audio

Refer to Table VI in our paper for the performance of a model trained on FSD50K.

Pretrained Weights

We include three pretrained weights for our encoder network.

Method  Dim.    Filename                           NSynth  US8K   VoxCeleb1  VoxForge  SPCV2/12  SPCV2  Average
BYOL-A  512-d   AudioNTT2020-BYOLA-64x96d512.pth   69.1%   78.2%  33.4%      83.5%     86.5%     88.9%  73.3%
BYOL-A  1024-d  AudioNTT2020-BYOLA-64x96d1024.pth  72.7%   78.2%  38.0%      88.5%     90.1%     91.4%  76.5%
BYOL-A  2048-d  AudioNTT2020-BYOLA-64x96d2048.pth  74.1%   79.1%  40.1%      90.2%     91.0%     92.2%  77.8%

License

This implementation is provided for evaluating the BYOL-A paper; see LICENSE for details.

Acknowledgements

BYOL-A is built on top of byol-pytorch, a BYOL implementation by Phil Wang (@lucidrains). We thank Phil for open-sourcing this sophisticated code.

@misc{wang2020byol-pytorch,
  author =       {Phil Wang},
  title =        {Bootstrap Your Own Latent (BYOL), in Pytorch},
  howpublished = {\url{https://github.com/lucidrains/byol-pytorch}},
  year =         {2020}
}

Comments
  • Question for reproducing results

    Hi,

    Thanks for sharing this great work! I tried to reproduce the results using the official guidance but I failed.

    After processing the data, I ran the following commands:

    CUDA_VISIBLE_DEVICES=0 python -W ignore train.py work/16k/fsd50k/FSD50K.dev_audio
    cp lightning_logs/version_4/checkpoints/epoch\=99-step\=16099.ckpt AudioNTT2020-BYOLA-64x96d2048.pth
    CUDA_VISIBLE_DEVICES=4 python evaluate.py AudioNTT2020-BYOLA-64x96d2048.pth spcv2
    

    However, the results are far from the reported ones.


    Did I miss something important? Thank you very much.

    question 
    opened by ChenyangLEI 15
  • Evaluation on voxforge

    Hi,

    Thank you so much for your contribution. This work is very interesting, and your code is easy to follow. But one of the downstream datasets, VoxForge, is missing from preprocess_ds.py. Could you please release the code for that dataset, too?

    Thank you again for your time.

    Best regards

    opened by Huiimin5 9
  • A mistake in RunningMean

    Thank you for the fascinating paper and the code to reproduce it!

    I think there might be a problem in RunningMean. The current formula (the same in v1 and v2) looks like this:

    $$ m_n = m_{n - 1} + \frac{a_n - m_{n - 1}}{n - 1}, $$

    which is inconsistent with the correct formula listed on StackOverflow:

    $$ m_n = m_{n - 1} + \frac{a_n - m_{n - 1}}{n}. $$

    The problem is that self.n is incremented after the new mean is computed. Could you please either correct me if I am wrong or correct the code?

    opened by WhiteTeaDragon 4
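
    For reference, a minimal sketch of the corrected update described in the issue above (a hypothetical standalone class, not the repository's exact code); incrementing n before the division yields the correct formula:

        class RunningMean:
            # Incremental mean: m_n = m_{n-1} + (a_n - m_{n-1}) / n.
            def __init__(self):
                self.n = 0
                self.mean = 0.0

            def update(self, a):
                self.n += 1                             # increment first ...
                self.mean += (a - self.mean) / self.n   # ... so we divide by n, not n - 1
                return self.mean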
  • A basic question: torch.randn(): argument 'size' must be tuple of ints, but found element of type list at pos 3

    Traceback (most recent call last):
      File "F:\IntellIDEA\PyCharm 2019.2.2\helpers\pydev\pydevd.py", line 2066, in <module>
        main()
      File "F:\IntellIDEA\PyCharm 2019.2.2\helpers\pydev\pydevd.py", line 2060, in main
        globals = debugger.run(setup['file'], None, None, is_module)
      File "F:\IntellIDEA\PyCharm 2019.2.2\helpers\pydev\pydevd.py", line 1411, in run
        return self._exec(is_module, entry_point_fn, module_name, file, globals, locals)
      File "F:\IntellIDEA\PyCharm 2019.2.2\helpers\pydev\pydevd.py", line 1418, in _exec
        pydev_imports.execfile(file, globals, locals)  # execute the script
      File "F:\IntellIDEA\PyCharm 2019.2.2\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
        exec(compile(contents+"\n", file, 'exec'), glob, loc)
      File "E:/pythonSpace/byol-a/train.py", line 132, in <module>
        main(audio_dir=base_path + '1/', epochs=100)
      File "E:/pythonSpace/byol-a/train.py", line 112, in main
        learner = BYOLALearner(model, cfg.lr, cfg.shape,
      File "E:/pythonSpace/byol-a/train.py", line 56, in __init__
        self.learner = BYOL(model, image_size=shape, **kwargs)
      File "D:\min\envs\torch1_7_1\lib\site-packages\byol_pytorch\byol_pytorch.py", line 211, in __init__
        self.forward(torch.randn(2, 3, image_size, image_size, device=device))
    TypeError: randn(): argument 'size' must be tuple of ints, but found element of type list at pos 3
    
    Not_an_issue 
    opened by a1030076395 3
  • Question about comments in train.py

    https://github.com/nttcslab/byol-a/blob/master/train.py

    At line 67, there is a comment on the shape of the input.

            # in fact, it should be (B, 1, F, T), e.g. (256, 1, 64, 96) where 64 is the number of mel bins
            paired_inputs = torch.cat(paired_inputs) # [(B,1,T,F), (B,1,T,F)] -> (2*B,1,T,F)
    


    However, it differs from the description in the config.yaml file:

    # Shape of log-mel spectrogram [F, T].
    shape: [64, 96]
    
    bug 
    opened by ChenyangLEI 2
  • Doubt in paper

    Hi there,

    Section 4, subsection A, part 1 from your paper says:

     The number of frames, T, in one segment was 96 in pretraining, which corresponds to 1,014ms. 
    

    However, the previous line says the hop size used was 10 ms. So according to this, 96 frames would mean 960 ms?

    Am I understanding something wrong here?

    Thank You in advance!

    question 
    opened by Sreyan88 2
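
    For reference, the 1,014 ms figure follows if the segment duration accounts for the analysis window as well as the hops. Assuming the 64 ms window stated alongside the 10 ms hop in the paper, T frames span (T - 1) hops plus one window:

    $$ (96 - 1) \times 10\,\mathrm{ms} + 64\,\mathrm{ms} = 1014\,\mathrm{ms}. $$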
  • Random crop is not working.

    https://github.com/nttcslab/byol-a/blob/60cebdc514951e6b42e18e40a2537a01a39ad47b/byol_a/dataset.py#L80-L82

    If len(wav) > self.unit_length, length_adj will be negative, so start will be 0. If wav (before padding) is shorter than the unit length, length_adj == 0 after padding, so start is again 0. Thus it always crops the fixed region from 0 to self.unit_length (cropped_wav == wav[0:self.unit_length]) instead of performing a random crop.

    So I think line 80 should be changed to length_adj = len(wav) - self.unit_length; see the sketch after this issue.

    bug 
    opened by JUiscoming 2
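
    A minimal sketch of the random crop with the fix suggested above (a hypothetical standalone function, not the repository's exact code):

        import random
        import torch.nn.functional as F

        def random_crop(wav, unit_length):
            # Pad right if the waveform is shorter than one unit, then pick a random start.
            length_adj = len(wav) - unit_length       # the suggested fix (operands were swapped)
            if length_adj < 0:
                wav = F.pad(wav, (0, -length_adj))    # pad up to unit_length
                length_adj = 0
            start = random.randint(0, length_adj)     # 0..length_adj inclusive
            return wav[start:start + unit_length]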
  • Doubt in RunningNorm

    Hi There, great repo!

    I think I may have misunderstood the RunningNorm function. The function expects the size of an epoch; however, your implementation passes the size of the entire dataset.

    Is it a bug? Or is there a problem with my understanding?

    Thank You!

    question 
    opened by Sreyan88 2
  • How to interpret the performance

    Hi, it's great work, but how should I interpret the performance metric? For example, VoxCeleb1 is usually used for speaker verification; shouldn't we measure EER?

    opened by ranchlai 2
  • Finetuning of BYOL-A

    Hi,

    Your paper is super interesting. I have a question regarding the downstream tasks. If I understand the paper correctly, you used a single linear layer for the downstream tasks, which takes only the sum of the mean and max of the representation over time as input.

    Did you try to finetune BYOL-A end-to-end after pretraining on the downstream tasks? In the case of TRILL, they were able to improve performance even further by finetuning the whole model end-to-end. Is there a specific reason why this is not possible with BYOL-A?

    questions 
    opened by mschiwek 1
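
    For context, a minimal sketch of such a downstream head (a hypothetical module; it assumes frame-level features of shape (B, T, D) and sums the temporal mean- and max-pooled features before a single linear layer, as described in the question above):

        import torch
        import torch.nn as nn

        class LinearProbeHead(nn.Module):
            def __init__(self, feature_d, num_classes):
                super().__init__()
                self.fc = nn.Linear(feature_d, num_classes)

            def forward(self, frame_features):                 # (B, T, D)
                pooled = frame_features.mean(dim=1) + frame_features.max(dim=1)[0]
                return self.fc(pooled)                         # logits: (B, num_classes)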
  • Missing scaling of validation samples in evaluate.py

    https://github.com/nttcslab/byol-a/blob/master/evaluate.py#L112

    It also needs X_val = scaler.transform(X_val), or the validation accuracy and loss will be invalid. This can be one of the reasons for the lower performance seen when trying to reproduce the official results...

    bug 
    opened by daisukelab 0