Powerful unsupervised domain adaptation method for dense retrieval.

Overview

Generative Pseudo Labeling (GPL)

GPL is an unsupervised domain adaptation method for training dense retrievers. It is based on query generation and pseudo labeling with powerful cross-encoders. To train a domain-adapted model, it needs only the unlabeled target corpus and can achieve significant improvement over zero-shot models.

For more information, check out our publication "GPL: Generative Pseudo Labeling for Unsupervised Domain Adaptation of Dense Retrieval" (see the Citation section below).

Installation

One can either install GPL via pip

pip install gpl

or via git clone

git clone https://github.com/UKPLab/gpl.git && cd gpl
pip install -e .

Usage

GPL accepts data in the BeIR-format. For example, we can download the FiQA dataset hosted by BeIR:

wget https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/fiqa.zip
unzip fiqa.zip
head -n 2 fiqa/corpus.jsonl  # Check the data format. GPL only needs this `corpus.jsonl` as data input for training.
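Each line of corpus.jsonl is a JSON object in the BeIR format. An illustrative example (the field values here are made up):

{"_id": "doc0", "title": "Example title", "text": "Example passage text.", "metadata": {}}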

Then we can either run GPL training directly via the python -m interface:

export dataset="fiqa"
python -m gpl.train \
    --path_to_generated_data "generated/$dataset" \
    --base_ckpt 'distilbert-base-uncased' \
    --batch_size_gpl 32 \
    --gpl_steps 140000 \
    --output_dir "output/$dataset" \
    --evaluation_data "./$dataset" \
    --evaluation_output "evaluation/$dataset" \
    --generator "BeIR/query-gen-msmarco-t5-base-v1" \
    --retrievers "msmarco-distilbert-base-v3" "msmarco-MiniLM-L-6-v3" \
    --cross_encoder "cross-encoder/ms-marco-MiniLM-L-6-v2" \
    --qgen_prefix "qgen" \
    --do_evaluation \
    # --use_amp   # Use this for efficient training if the machine supports AMP

# One can run `python -m gpl.train --help` for information on all the arguments
# To reproduce the experiments in the paper, set `base_ckpt` to "GPL/msmarco-distilbert-margin-mse" (https://huggingface.co/GPL/msmarco-distilbert-margin-mse)

or import GPL's training method in a Python script:

import gpl

dataset = 'fiqa'
gpl.train(
    path_to_generated_data=f"generated/{dataset}",
    base_ckpt='distilbert-base-uncased',  
    # base_ckpt='GPL/msmarco-distilbert-margin-mse',  # The starting checkpoint of the experiments in the paper
    batch_size_gpl=32,
    gpl_steps=140000,
    output_dir=f"output/{dataset}",
    evaluation_data=f"./{dataset}",
    evaluation_output=f"evaluation/{dataset}",
    generator="BeIR/query-gen-msmarco-t5-base-v1",
    retrievers=["msmarco-distilbert-base-v3", "msmarco-MiniLM-L-6-v3"],
    cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
    qgen_prefix="qgen",
    do_evaluation=True,
    # use_amp=True,  # One can use this argument to enable efficient float16 (AMP) training
)

How does GPL work?

The workflow of GPL is as follows:

  1. GPL first uses a seq2seq model (BeIR/query-gen-msmarco-t5-base-v1 by default) to generate queries_per_passage queries for each passage in the unlabeled corpus. The query-passage pairs are used as positive examples for training.

    Result files (under path $path_to_generated_data): (1) ${qgen}-qrels/train.tsv, (2) ${qgen}-queries.jsonl and also (3) corpus.jsonl (copied from $evaluation_data/);

  2. Then, it runs negative mining on the target corpus with the generated queries as input. The mined passages are used as negative examples for training. One can pass any dense retrievers (SBERT or Hugging Face Transformers checkpoints; msmarco-distilbert-base-v3 + msmarco-MiniLM-L-6-v3 by default) or BM25 to the retrievers argument as the negative miner.

    Result file (under path $path_to_generated_data): hard-negatives.jsonl;

  3. Finally, it runs pseudo labeling with a powerful cross-encoder (cross-encoder/ms-marco-MiniLM-L-6-v2 by default) on all the query-passage pairs gathered so far (both positive and negative examples).

    Result file (under path $path_to_generated_data): gpl-training-data.tsv. It contains (gpl_steps * batch_size_gpl) tuples in total.

Up to now, the actual training data is ready. One can look at sample-data/generated/fiqa for a quick example of the data format. The very last step is to apply the MarginMSE loss, which teaches the student retriever to mimic the margin scores CE(query, positive) - CE(query, negative) labeled by the teacher model (the cross-encoder, CE).
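As an illustration, here is a minimal PyTorch sketch of this MarginMSE objective (not the exact implementation in gpl/toolkit/loss.py; the embedding inputs and names are placeholders):

import torch
import torch.nn.functional as F

def margin_mse_loss(
    q_emb: torch.Tensor,       # query embeddings, shape (batch, dim)
    pos_emb: torch.Tensor,     # positive-passage embeddings, shape (batch, dim)
    neg_emb: torch.Tensor,     # negative-passage embeddings, shape (batch, dim)
    teacher_margin: torch.Tensor,  # CE(q, p+) - CE(q, p-), shape (batch,)
) -> torch.Tensor:
    # Student margin from dot-product scores: score(q, p+) - score(q, p-)
    student_margin = (q_emb * pos_emb).sum(dim=-1) - (q_emb * neg_emb).sum(dim=-1)
    # Regress the student margin onto the teacher (cross-encoder) margin
    return F.mse_loss(student_margin, teacher_margin)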

Customized data

One can also place customized data for any intermediate step under the path $path_to_generated_data, following the same naming scheme. GPL will then skip the corresponding intermediate steps and use the provided data instead.
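Putting the file names from the steps above together, a fully populated $path_to_generated_data directory looks like this (with the default qgen_prefix):

$path_to_generated_data/
    corpus.jsonl             # unlabeled target corpus
    qgen-queries.jsonl       # generated queries (step 1)
    qgen-qrels/train.tsv     # positive query-passage pairs (step 1)
    hard-negatives.jsonl     # mined negative passages (step 2)
    gpl-training-data.tsv    # pseudo-labeled training tuples (step 3)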

Citation

If you use the code for evaluation, feel free to cite our publication GPL: Generative Pseudo Labeling for Unsupervised Domain Adaptation of Dense Retrieval:

@article{wang2021gpl,
    title = "GPL: Generative Pseudo Labeling for Unsupervised Domain Adaptation of Dense Retrieval",
    author = "Kexin Wang and Nandan Thakur and Nils Reimers and Iryna Gurevych", 
    journal= "arXiv preprint arXiv:2112.07577",
    month = "4",
    year = "2021",
    url = "https://arxiv.org/abs/2112.07577",
}

Contact person and main contributor: Kexin Wang, [email protected]

https://www.ukp.tu-darmstadt.de/

https://www.tu-darmstadt.de/

Don't hesitate to send us an e-mail or report an issue, if something is broken (and it shouldn't be) or if you have further questions.

This repository contains experimental software and is published for the sole purpose of giving additional background details on the respective publication.

Comments
  • Error while running the training script

    [2022-04-14 06:00:25] INFO [gpl.toolkit.pl.run:60] Begin pseudo labeling
    0%| | 0/140000 [00:00<?, ?it/s]
    Traceback (most recent call last):
      File "/home/ec2-user/SageMaker/gpl/gpl/toolkit/pl.py", line 63, in run
        batch = next(hard_negative_iterator)
      File "/home/ec2-user/SageMaker/kernels/gpl_env/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 530, in __next__
        data = self._next_data()
      File "/home/ec2-user/SageMaker/kernels/gpl_env/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 569, in _next_data
        index = self._next_index()  # may raise StopIteration
      File "/home/ec2-user/SageMaker/kernels/gpl_env/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 521, in _next_index
        return next(self._sampler_iter)  # may raise StopIteration
    StopIteration

    opened by kingafy 3
  • Loss function

    Is the minus sign "-" in the MarginMSE loss function in Equation (1) of the GPL paper a typo?

    There should be no minus sign "-": the model should minimize MSE(delta_teacher, delta_student), not maximize it. I checked the released code of GPL; the loss function is implemented without the minus sign.

    opened by dli1 2
  • GPU speedup

    I reckon this is more of a generic question for TSDAE + GPL (or any transformer used), but can you use the GPU by simply doing something like gpl.to(device)?

    opened by ahadda5 1
  • [KTLO-6] Hints for missing evaluation data

    The previous code did not give enough hints about missing evaluation data:

    • gpl/toolkit/evaluation.py: Added checking for missing evaluation data
    • tests/unit/conftest.py: Separated sbert and sbert_path fixtures
    • tests/unit/test_eval.py: Added test
    opened by kwang2049 0
  • [KTLO-5] batch size larger than data size

    The previous code did not check whether the batch size is larger than the number of data points (i.e. the number of generated queries) in PseudoLabeler.run:

    • gpl/toolkit/pl.py: Added a check at the beginning of run comparing batch size vs. data size
    • tests/unit/test_pl.py: Added test
    opened by kwang2049 0
  • [KTLO-4] OOM error in qgen

    The previous code did not detect OOM errors in QGen, which can be caused by a large QPP or batch size:

    • modified: gpl/toolkit/qgen.py: Added try/catch
    • new file: tests/unit/test_qgen.py: Added test

    opened by kwang2049 0
  • [KTLO-3] OOM error in loadable checking

    The current version could not identify OOM errors in loadable_by_sbert_oom, since OOM is also a runtime error and this loadable check views all runtime errors as "not loadable":

    • modified: gpl/toolkit/sbert.py: Raise OOM error (runtime error)
    • modified: setup.py: Added pytest
    • new file: tests/unit/conftest.py: SBERT fixture
    • new file: tests/unit/test_sbert.py: Test OOM error case
    opened by kwang2049 0
  • [KTLO-0] New EES version and black formatting


    • README.md: Hint of installing PyTorch correctly wrt. the CUDA version.
    • gpl/toolkit/beir.py: Black
    • gpl/toolkit/dataset.py: Black
    • gpl/toolkit/evaluation.py: Black
    • gpl/toolkit/log.py: Black
    • gpl/toolkit/loss.py: Black
    • gpl/toolkit/mine.py: Black
    • gpl/toolkit/mnrl.py: Black
    • gpl/toolkit/pl.py: Black
    • gpl/toolkit/qgen.py: Black
    • gpl/toolkit/reformat.py: Black
    • gpl/toolkit/rescale.py: Black
    • gpl/toolkit/resize.py: Black
    • gpl/toolkit/sbert.py: Black
    • gpl/train.py: Black
    • setup.py: Added protobuf, which is required by T5 but seems not to be installed with transformers alone; specified ees>=0.0.8 (which keeps the elasticsearch version the same as that required by beir)
    opened by kwang2049 0
  • Should the learning domain contain only assertion texts (like "Python is a high-level general-purpose programming language")?

    Hi. Should the learning domain contain only assertion texts (like "Python is a high-level general-purpose programming language" in your example)? In your pipeline the first step is Query Generation: "For a given text from our domain, we first use a T5 model that generates a possible query for the given text. E.g. when your text is 'Python is a high-level general-purpose programming language', the model might generate a query like 'What is Python'. You can find various query generators on our doc2query-hub." Does that mean that texts which couldn't be converted into queries (e.g. "Investment consulting for legal entities and individuals.") cannot be used for training?

    opened by edgar2597 0
  • GPL for sentence embedding tasks?

    In the provided examples GPL is used for semantic search tasks: given a query, relevant results should be retrieved. Is it also the recommended approach to get meaningful embeddings / bi-encoders, or is it better to use TSDAE?

    opened by hanshupe 2
  • Guidance on gpl_steps, new_size and batch_size_gpl

    Hello,

    I am looking for some guidance on the below parameters of gpl.train().

    • gpl_steps - Do we need such a huge value of 140000 for a corpus of size 1300?
    • new_size
    • batch_size_gpl - Would it help to speed up the training if we keep this at 64 or 128? How should the values of these parameters be derived from the dataset or corpus.jsonl?
    opened by MyBruso 0
  • TSDAE to GPL... Error on start

    I'm trying to go from my trained TSDAE model and then apply GPL... However, I keep getting errors.

    ! export dataset="hs_resume_tsdae_gpl_mini" 
    ! python -m gpl.train \
        --path_to_generated_data "generated/$dataset" \
        --base_ckpt "/Users/cfeld/Desktop/dev/trajectory/finetuning/gpl/outputs/tsdae/MiniLM-L6-H384-uncased-model" \
        --gpl_score_function "dot" \
        --batch_size_gpl 34 \
        --gpl_steps 100 \
        --queries_per_passage 1 \
        --output_dir "output/$dataset" \
        --evaluation_data "./$dataset" \
        --evaluation_output "evaluation/$dataset" \
        --generator "BeIR/query-gen-msmarco-t5-base-v1" \
        --retrievers "msmarco-distilbert-base-v3" "msmarco-MiniLM-L-6-v3" \
        --retriever_score_functions "cos_sim" "cos_sim" \
        --cross_encoder "cross-encoder/ms-marco-MiniLM-L-6-v2" \
        --use_train_qrels
    

    However, I'm getting this error:

    2022-09-12 17:37:44 - Loading faiss.
    2022-09-12 17:37:44 - Successfully loaded faiss.
    /opt/homebrew/Caskroom/miniconda/base/envs/finetune_hs/lib/python3.9/runpy.py:127: RuntimeWarning: 'gpl.train' found in sys.modules after import of package 'gpl', but prior to execution of 'gpl.train'; this may result in unpredictable behaviour
      warn(RuntimeWarning(msg))
    [2022-09-12 17:37:44] INFO [gpl.train.train:79] Corpus does not exist in generated/. Now clone the one from the evaluation path ./
    [2022-09-12 17:37:44] WARNING [gpl.train.train:106] Found `qgen_prefix` is not None. By setting `use_train_qrels == True`, the `qgen_prefix` will not be used
    [2022-09-12 17:37:44] INFO [gpl.train.train:113] Loading qrels and queries from labeled data under the path of `evaluation_data`
    Traceback (most recent call last):
      File "/opt/homebrew/Caskroom/miniconda/base/envs/finetune_hs/lib/python3.9/runpy.py", line 197, in _run_module_as_main
        return _run_code(code, main_globals, None,
      File "/opt/homebrew/Caskroom/miniconda/base/envs/finetune_hs/lib/python3.9/runpy.py", line 87, in _run_code
        exec(code, run_globals)
      File "/opt/homebrew/Caskroom/miniconda/base/envs/finetune_hs/lib/python3.9/site-packages/gpl/train.py", line 250, in <module>
        train(**vars(args))
      File "/opt/homebrew/Caskroom/miniconda/base/envs/finetune_hs/lib/python3.9/site-packages/gpl/train.py", line 114, in train
        assert 'qrels' in os.listdir(evaluation_data) and 'queries.jsonl' in os.listdir(evaluation_data)
    AssertionError
    

    Perhaps my folder structure isn't quite right? I've tried all kinds of combos... Folder:

        corpus.jsonl
        evaluation/
            corpus.jsonl
            hs_resume_tsdae_gpl_mini/
                corpus.jsonl
        generated/
            corpus.jsonl
            hs_resume_tsdae_gpl_mini/
                corpus.jsonl
        hs_resume_tsdae_gpl_mini/
            corpus.jsonl
        output/
            hs_resume_tsdae_gpl_mini/

    opened by christophermfeld 1
  • Evaluation data format

    Hi,

    1/ How should the evaluation data passed via the evaluation_data argument be formatted? Could you provide an example of evaluation data and how it should be structured?

    2/ How does the evaluation work on these data? Which tests are run and which labels are used?

    Thanks!

    opened by Matthieu-Tinycoaching 0
  • RuntimeError: CUDA out of memory


    Hi,

    When trying to generate intermediate results with the following command:

    dataset = 'tiny'
    gpl.train(
        path_to_generated_data=f"generated/{dataset}",
        base_ckpt='sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2',  
        # base_ckpt='GPL/msmarco-distilbert-margin-mse',  # The starting checkpoint of the experiments in the paper
        gpl_score_function="dot",
        # Note that GPL uses MarginMSE loss, which works with dot-product
        batch_size_gpl=32,
        gpl_steps=140000,
        new_size=-1,
        # Resize the corpus to `new_size` (|corpus|) if needed. When set to None (by default), |corpus| will be the full size. When set to -1, |corpus| will be set automatically: if QPP * |corpus| <= 250K, |corpus| will be the full size; else QPP will be set to 3 and |corpus| will be set to 250K / 3
        queries_per_passage=-1,
        # Number of Queries Per Passage (QPP) in the query generation step. When set to -1 (by default), QPP will be chosen automatically: if QPP * |corpus| <= 250K, then QPP will be set to 250K / |corpus|; else QPP will be set to 3 and |corpus| will be set to 250K / 3
        output_dir=f"output/{dataset}",
        evaluation_data=f"./{dataset}",
        evaluation_output=f"evaluation/{dataset}",
        generator="BeIR/query-gen-msmarco-t5-large-v1",
        retrievers=["msmarco-distilbert-base-tas-b", "msmarco-MiniLM-L6-cos-v5"],
        retriever_score_functions=["dot", "cos_sim"],
        # Note that these two retriever models work with dot-product and cosine-similarity, respectively
        cross_encoder="cross-encoder/ms-marco-MiniLM-L-6-v2",
        qgen_prefix="qgen",
        # This prefix will appear as part of the (folder/file) names for query-generation results: For example, we will have "qgen-qrels/" and "qgen-queries.jsonl" by default.
        do_evaluation=True,
        use_amp=True   # One can use this flag for enabling the efficient float16 precision
    )
    

    I got the following error:

    2022-08-26 11:55:08 - Loading faiss with AVX2 support.
    2022-08-26 11:55:08 - Could not load library with AVX2 support due to:
    ModuleNotFoundError("No module named 'faiss.swigfaiss_avx2'")
    2022-08-26 11:55:08 - Loading faiss.
    2022-08-26 11:55:08 - Successfully loaded faiss.
    [2022-08-26 11:55:10] INFO [gpl.train.train:79] Corpus does not exist in generated/tiny. Now clone the one from the evaluation path ./tiny
    [2022-08-26 11:55:10] INFO [gpl.train.train:84] Automatically set `new_size` to 83334
    [2022-08-26 11:55:10] INFO [beir.datasets.data_loader.load_corpus:89] Loading Corpus...
    100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4252/4252 [00:00<00:00, 277639.61it/s]
    [2022-08-26 11:55:10] INFO [beir.datasets.data_loader.load_corpus:91] Loaded 4252 Documents.
    [2022-08-26 11:55:10] INFO [beir.datasets.data_loader.load_corpus:92] Doc Example: {'text': 'Without a specific goal for your speech, your audience will be lost in understanding the message you are seeking to deliver, because you will not know yourself what you are seeking to deliver in that speech.', 'title': ''}
    [2022-08-26 11:55:10] WARNING [gpl.toolkit.resize.resize:19] `new_size` should be smaller than the corpus size
    [2022-08-26 11:55:10] INFO [gpl.toolkit.resize.resize:41] Resized the corpus in ./tiny to generated/tiny with new size 83334
    [2022-08-26 11:55:10] INFO [beir.datasets.data_loader.load_corpus:89] Loading Corpus...
    100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4252/4252 [00:00<00:00, 321974.74it/s]
    [2022-08-26 11:55:10] INFO [beir.datasets.data_loader.load_corpus:91] Loaded 4252 Documents.
    [2022-08-26 11:55:10] INFO [beir.datasets.data_loader.load_corpus:92] Doc Example: {'text': 'Without a specific goal for your speech, your audience will be lost in understanding the message you are seeking to deliver, because you will not know yourself what you are seeking to deliver in that speech.', 'title': ''}
    [2022-08-26 11:55:10] INFO [gpl.train.train:99] Automatically set `queries_per_passage` to 59
    [2022-08-26 11:55:10] INFO [gpl.train.train:125] No generated queries found. Now generating it
    [2022-08-26 11:55:10] INFO [beir.datasets.data_loader.load_corpus:89] Loading Corpus...
    100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4252/4252 [00:00<00:00, 308459.11it/s]
    [2022-08-26 11:55:10] INFO [beir.datasets.data_loader.load_corpus:91] Loaded 4252 Documents.
    [2022-08-26 11:55:10] INFO [beir.datasets.data_loader.load_corpus:92] Doc Example: {'text': 'Without a specific goal for your speech, your audience will be lost in understanding the message you are seeking to deliver, because you will not know yourself what you are seeking to deliver in that speech.', 'title': ''}
    [2022-08-26 11:55:20] INFO [beir.generation.models.auto_model.__init__:16] Use pytorch device: cuda
    [2022-08-26 11:55:21] INFO [beir.generation.generate.generate:40] Starting to Generate 59 Questions Per Passage using top-p (nucleus) sampling...
    [2022-08-26 11:55:21] INFO [beir.generation.generate.generate:41] Params: top_p = 0.95
    [2022-08-26 11:55:21] INFO [beir.generation.generate.generate:42] Params: top_k = 25
    [2022-08-26 11:55:21] INFO [beir.generation.generate.generate:43] Params: max_length = 64
    [2022-08-26 11:55:21] INFO [beir.generation.generate.generate:44] Params: ques_per_passage = 59
    [2022-08-26 11:55:21] INFO [beir.generation.generate.generate:45] Params: batch size = 32
    pas:   0%|                                                                                                                                                                                          | 0/133 [00:00<?, ?it/s]
    Traceback (most recent call last):
      File "/home/matthieu/Tinycoaching/GPL/v.0.1.0/gpl_query_generation.py", line 316, in <module>
        gpl.train(
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/gpl/train.py", line 127, in train
        qgen(path_to_generated_data, path_to_generated_data, generator_name_or_path=generator, ques_per_passage=queries_per_passage, bsz=batch_size_generation, qgen_prefix=qgen_prefix)
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/gpl/toolkit/qgen.py", line 23, in qgen
        generator.generate(corpus, output_dir=output_dir, ques_per_passage=ques_per_passage, prefix=prefix, batch_size=bsz)
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/beir/generation/generate.py", line 54, in generate
        queries = self.model.generate(
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/beir/generation/models/auto_model.py", line 28, in generate
        outs = self.model.generate(
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
        return func(*args, **kwargs)
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/transformers/generation_utils.py", line 1326, in generate
        return self.sample(
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/transformers/generation_utils.py", line 1944, in sample
        outputs = self(
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/transformers/models/t5/modeling_t5.py", line 1639, in forward
        decoder_outputs = self.decoder(
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/transformers/models/t5/modeling_t5.py", line 1035, in forward
        layer_outputs = layer_module(
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/transformers/models/t5/modeling_t5.py", line 692, in forward
        cross_attention_outputs = self.layer[1](
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/transformers/models/t5/modeling_t5.py", line 606, in forward
        attention_output = self.EncDecAttention(
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/matthieu/anaconda3/envs/gpl_0.1.0/lib/python3.8/site-packages/transformers/models/t5/modeling_t5.py", line 509, in forward
        scores = torch.matmul(
    RuntimeError: CUDA out of memory. Tried to allocate 584.00 MiB (GPU 0; 23.70 GiB total capacity; 20.69 GiB already allocated; 587.94 MiB free; 20.83 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
    

    My corpus consists of small paragraphs of 3-4 lines and I used the use_amp option. How can I deal with this?

    opened by Matthieu-Tinycoaching 1
Releases
  • v0.1.4 (Sep 29, 2022)

  • v0.1.3 (Sep 26, 2022)

    Previously, there was a conflict between easy_elasticsearch and beir on the dependency of elasticsearch:

    • easy_elasticsearch requires elasticsearch==7.12.1 while
    • beir requires elasticsearch==7.9.1

    In the latest version of easy_elasticsearch, the requirements have been changed to solve this issue. Here we update gpl to install this version (easy_elasticsearch==0.0.9). Another update in easy_elasticsearch==0.0.9 is that it solves the issue that ES could return empty results (because refresh was not called during indexing).

  • v0.1.0 (Apr 19, 2022)

    Updated paper, accepted by NAACL 2022

    The GPL paper has been accepted by NAACL 2022! Major updates:

    • Improved the setting: down-sample the corpus if it is too large; calculate the number of generated queries according to the corpus size;
    • Added more analysis of the influence of the number of generated queries: small corpora need more queries;
    • Added results on the full 18 BeIR datasets: the conclusions remain the same, and we also trained GPL on top of the powerful TAS-B model and achieved new improvements.

    Automatic hyper-parameter

    Previously, we used the whole corpus with the number of generated queries fixed at 3, no matter the corpus size. This results in very poor training efficiency for large corpora. In the new version, we set these two hyper-parameters automatically to meet the standard: the total number of generated queries ≈ 250K.

    In detail, we require queries_per_passage >= 3 and uniformly down-sample the corpus if 3 × |C| > 250K, where |C| is the corpus size; otherwise we calculate queries_per_passage = 250K / |C| (rounded up). For example, the queries_per_passage values for FiQA (original size = 57.6K) and Robust04 (original size = 528.2K) are 5 and 3, respectively, and the Robust04 corpus is down-sampled to 83.3K.
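    A minimal Python sketch of this rule (the function name and the exact rounding are assumptions; see gpl/train.py for the actual logic):

    import math

    def auto_hparams(corpus_size, total_queries=250_000, min_qpp=3):
        # Large corpus: fix queries_per_passage at 3 and down-sample the corpus
        if min_qpp * corpus_size > total_queries:
            return total_queries // min_qpp, min_qpp
        # Small corpus: keep the full corpus and raise queries_per_passage
        return corpus_size, math.ceil(total_queries / corpus_size)

    print(auto_hparams(57_600))   # FiQA: (57600, 5)
    print(auto_hparams(528_200))  # Robust04: (83333, 3)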

    Released checkpoints (TAS-B ones)

    We now release the pre-trained GPL models at https://huggingface.co/GPL. They also include the powerful GPL models trained on top of TAS-B.

  • v0.0.9 (Jan 11, 2022)

    Fixed bug of max.-sequence-length mismatch between student and teacher

    Previously, the teacher (i.e. the cross-encoder) received the concatenation of the query and document texts as input and had no limit on the max. sequence length. However, the student had max.-sequence-length limits on the query texts and the document texts separately. This caused a mismatch between the information visible to the student and the teacher models.

    In the new release, we fixed this by doing "retokenization": right before pseudo labeling, we let the tokenizer of the teacher model tokenize the query texts and the document texts separately and then decode the results (token IDs) back into texts. The resulting texts meet the same max.-sequence-length requirements as the student model sees, which fixes the bug.
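    As an illustration, the retokenization step might look like the sketch below (not the exact code in the repo; the max_length value is a placeholder):

    from transformers import AutoTokenizer

    # Tokenizer of the teacher model (the default cross-encoder)
    tokenizer = AutoTokenizer.from_pretrained("cross-encoder/ms-marco-MiniLM-L-6-v2")

    def retokenize(text: str, max_length: int = 350) -> str:
        # Tokenize with truncation to the student's limit, then decode back to text
        ids = tokenizer(text, truncation=True, max_length=max_length)["input_ids"]
        return tokenizer.decode(ids, skip_special_tokens=True)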

    Keep full precision of the pseudo labels

    Previously, we saved the pseudo labels from PyTorch tensors directly, which did not preserve full precision. We have fixed this by calling labels.tolist() right before dumping the data. The effect is likely small, since the previous 6-digit precision was already high enough.

  • v0.0.8 (Dec 20, 2021)

    Independent evaluation and k_values supported

    One can now run gpl.toolkit.evaluation directly. Previously, this was only possible as part of the whole gpl.train workflow. Please check this example for more details.

    We have also added the argument k_values to gpl.toolkit.evaluation.evaluate. It specifies the K values in metrics such as "NDCG@K" and "Recall@K".

    Fixed bugs & use load_sbert in mnrl and evaluation

    Now almost all methods that require a separation token have this argument, called sep (previously it was fixed as a blank token " "). Two exceptions are mnrl (a loss function in the SBERT repo, also the default training loss for the QGen method) and qgen, since they come from the BeIR repo (we will update the BeIR repo in the future if possible).

  • v0.0.7 (Dec 17, 2021)

    Rewrite SBERT loading

    Previously, GPL loaded starting checkpoints (--base_ckpt) by constructing the SBERT model from scratch. This could lose some information from the checkpoint (e.g. pooling and max_seq_length), so one needed to specify these settings carefully.

    Now we have created a method called load_sbert. It uses SentenceTransformer(base_ckpt) to load the checkpoint directly and performs some checks & assertions. Loading from a Huggingface-format checkpoint (e.g. "distilbert-base-uncased") is still possible in many cases, as before, but we recommend loading from an SBERT-format checkpoint whenever possible, since this makes it less likely to misuse the starting checkpoint.
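    For context, SBERT-format loading boils down to the call below (a sketch; load_sbert performs extra checks on top of it):

    from sentence_transformers import SentenceTransformer

    # An SBERT-format checkpoint carries its own pooling and max_seq_length
    # configuration, so nothing has to be specified manually
    model = SentenceTransformer("GPL/msmarco-distilbert-margin-mse")
    print(model.max_seq_length)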

    Reformatting examples

    In some cases, a Huggingface-format checkpoint cannot be loaded directly by SBERT, e.g. "facebook/dpr-question_encoder-single-nq-base". This is because:

    1. Of course, they are not in SBERT-format but in Huggingface-format;
    2. And for Huggingface-format, SBERT can only work with checkpoints whose last layer is a Transformer layer, i.e. the outputs should contain hidden states with shape (batch_size, sequence_length, hidden_dimension).

    To use these checkpoints, one needs to reformat them into SBERT-format. We have provided two examples/templates in the new toolkit source file, gpl/toolkit/reformat.py. Please refer to its readme here.

    Solved logging bug

    Previously, the logging in GPL was overridden by some other loggers and the formatting could not be displayed as we wanted. Now we have solved this by configuring the root logger. The new formatting shows many useful details:

    fmt='[%(asctime)s] %(levelname)s [%(name)s.%(funcName)s:%(lineno)d] %(message)s'
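    For reference, an equivalent root-logger configuration can be set up with Python's standard logging module (a sketch, not necessarily the repo's exact code):

    import logging

    logging.basicConfig(
        level=logging.INFO,
        format="[%(asctime)s] %(levelname)s [%(name)s.%(funcName)s:%(lineno)d] %(message)s",
    )
    logging.getLogger(__name__).info("Logging configured")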
    
Owner
Ubiquitous Knowledge Processing Lab