Repo for WWW 2022 paper: Progressively Optimized Bi-Granular Document Representation for Scalable Embedding Based Retrieval

Related tags

Deep LearningBiDR
Overview

BiDR

Repo for WWW 2022 paper: Progressively Optimized Bi-Granular Document Representation for Scalable Embedding Based Retrieval.

Requirements

torch==1.7
transformers==4.6
faiss-gpu==1.6.4.post2

Data Download and Preprocess

bash download_data.sh
python preprocess.py

These commands will download and preprocess the MSMARCO Passage and Doc dataset, then the resutls will be saved to ./data.
We take the Passage dataset as the example to show the running workflow.

Conventional Workflow

Representation Learning

Train the encoder with random negative (or set --hardneg_json to provied bm25/hard negatives) :

mkdir log
dataset=passage
savename=dense_global_model
python train.py --model_name_or_path roberta-base \
--max_query_length 24 --max_doc_length 128 \
--data_dir ./data/${dataset}/preprocess \
--learning_rate 1e-4 --optimizer_str adamw \
--per_device_train_batch_size 128 \
--per_query_neg_num 1 \
--generate_batch_method random \
--loss_method multi_ce  \
--savename ${savename} --save_model_path ./model \
--world_size 8 --gpu_rank 0_1_2_3_4_5_6_7  --master_port 13256 \
--num_train_epochs 30  \
--use_pq False \
|tee ./log/${savename}.log

Unsupervised Quantization

Generate dense embeddings of queries and docs:

data_type=passage
savename=dense_global_model
epoch=20
python ./inference.py \
--data_type ${data_type} \
--preprocess_dir ./data/${data_type}/preprocess/ \
--max_doc_length 256 --max_query_length 32 \
--eval_batch_size 512 \
--ckpt_path ./model/${savename}/${epoch}/ \
--output_dir  evaluate/${savename}_${epoch} 

Product Quantization based on Faiss and recall performance:

data_type=passage
savename=dense_global_model
epoch=20
python ./test/lightweight_ann.py \
--output_dir ./data/${data_type}/evaluate/${savename}_${epoch} \
--ckpt_path /model/${savename}/${epoch}/ \
--subvector_num 96 \
--index opq \
--topk 1000 \
--data_type ${data_type} \
--MRR_cutoff 10 \
--Recall_cutoff 5 10 30 50 100

Progressively Optimized Bi-Granular Document Representation

Sparse Representation Learning

Instead of running unsupervised quantization for the well-learned dense embeddings, the sparse embeddings are generated from contrastive learning, which optimizes the global discrimination and helps to enable high-quality answers to be covered in candidate search.

Train

We find that using Faiss OPQ to initialize the PQ module has a significant gain for MSMARCO dataset. But for the largest dataset: Ads dataset, initialization with Faiss OPQ is redundant and has no promotion.

dataset=passage
savename=sparse_global_model
python train.py --model_name_or_path ./model/dense_global_model/20 \
--max_query_length 24 --max_doc_length 128 \
--data_dir ./data/${dataset}/preprocess \
--learning_rate 1e-4 --optimizer_str adamw \
--per_device_train_batch_size 128 \
--per_query_neg_num 1 \
--generate_batch_method random \
--loss_method multi_ce  \
--savename ${savename} --save_model_path ./model \
--world_size 8 --gpu_rank 0_1_2_3_4_5_6_7  --master_port 13256 \
--num_train_epochs 30  \
--use_pq True \
--init_index_path ./data/${data_type}/evaluate/dense_global_model_20/OPQ96,PQ96x8.index \
--partition 96 --centroids 256 --quantloss_weight 0.0 \
|tee ./log/${savename}.log

where the ./model/dense_global_model/20 and ./data/${data_type}/evaluate/dense_global_model_20/OPQ96,PQ96x8.index is generated by conventional workflow.

Test

data_type=passage
savename=sparse_global_model
epoch=20

python ./inference.py \
--data_type ${data_type} \
--preprocess_dir ./data/${data_type}/preprocess/ \
--max_doc_length 256 --max_query_length 32 \
--eval_batch_size 512 \
--ckpt_path ./model/${savename}/${epoch}/ \
--output_dir  evaluate/${savename}_${epoch} 

python ./test/lightweight_ann.py \
--output_dir ./data/${data_type}/evaluate/${savename}_${epoch} \
--subvector_num 96 \
--index opq \
--topk 1000 \
--data_type ${data_type} \
--MRR_cutoff 10 \
--Recall_cutoff 5 10 30 50 100 \
--ckpt_path ./model/${savename}/${epoch}/ \
--init_index_path ./data/${data_type}/evaluate/dense_global_model_20/OPQ96,PQ96x8.index

Dense Representation Learning

The dense embeddings are optimized based on the candidate distribution generated by sparse embeddings. We propose a novel sampling strategy called locality-centric sampling to enhance local discrimination: construct a bipartite proximity graph and conduct random walk or snow sample on it.

Train

Encode the quries in train set and generate the candidates for all train queries:

data_type=passage
savename=sparse_global_model
epoch=20

python ./inference.py \
--data_type ${data_type} \
--preprocess_dir ./data/${data_type}/preprocess/ \
--max_doc_length 256 --max_query_length 32 \
--eval_batch_size 512 \
--ckpt_path ./model/${savename}/${epoch}/ \
--output_dir  evaluate/${savename}_${epoch} \
--mode train

python ./test/lightweight_ann.py \
--output_dir ./data/${data_type}/evaluate/${savename}_${epoch} \
--subvector_num 96 \
--index opq \
--topk 1000 \
--data_type ${data_type} \
--MRR_cutoff 10 \
--Recall_cutoff 5 10 30 50 100 \
--ckpt_path ./model/${savename}/${epoch}/ \
--init_index_path ./data/${data_type}/evaluate/dense_global_model_20/OPQ96,PQ96x8.index \
--mode train \
--save_hardneg_to_json

This command will save the train_hardneg.json to output_dir. Then train the dense embeddings to distinguish the ground truth from the negative in candidate:

dataset=passage
savename=dense_local_model
python train.py --model_name_or_path roberta-base \
--max_query_length 24 --max_doc_length 128 \
--data_dir ./data/${dataset}/preprocess \
--learning_rate 1e-4 --optimizer_str adamw \
--per_device_train_batch_size 128 \
--per_query_neg_num 1 \
--generate_batch_method {random_walk or snow_sample} \
--loss_method multi_ce  \
--savename ${savename} --save_model_path ./model \
--world_size 8 --gpu_rank 0_1_2_3_4_5_6_7  --master_port 13256 \
--num_train_epochs 30  \
--use_pq False \
--hardneg_json ./data/${data_type}/evaluate/sparse_global_model_20/train_hardneg.json \
--mink 0  --maxk 200 \
|tee ./log/${savename}.log

Test

data_type=passage
savename=dense_local_model
epoch=10

python ./inference.py \
--data_type ${data_type} \
--preprocess_dir ./data/${data_type}/preprocess/ \
--ckpt_path ./model/${savename}/${epoch}/ \
--max_doc_length 256 --max_query_length 32 \
--eval_batch_size 512 \
--ckpt_path ./model/${savename}/${epoch}/ \
--output_dir  evaluate/${savename}_${epoch} 

python ./test/post_verification.py \
--data_type ${data_type} \
--output_dir  evaluate/${savename}_${epoch} \
--candidate_from_ann ./data/${data_type}/evaluate/sparse_global_model_20/dev.rank_1000_score_faiss_opq.tsv \
--MRR_cutoff 10 \
--Recall_cutoff 5 10 30 50 100

Contributing

This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.

When you submit a pull request, a CLA bot will automatically determine whether you need to provide a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repos using our CLA.

This project has adopted the Microsoft Open Source Code of Conduct. For more information see the Code of Conduct FAQ or contact [email protected] with any additional questions or comments.

Trademarks

This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow Microsoft's Trademark & Brand Guidelines. Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. Any use of third-party trademarks or logos are subject to those third-party's policies.

Owner
Microsoft
Open source projects and samples from Microsoft
Microsoft
This is the code of "Multi-view Contrastive Graph Clustering" in NeurlPS 2021.

MCGC Description This is the code of "Multi-view Contrastive Graph Clustering" in NeurlPS 2021. Datasets Results ACM DBLP IMDB Amazon photos Amazon co

31 Nov 14, 2022
Official PyTorch Code of GrooMeD-NMS: Grouped Mathematically Differentiable NMS for Monocular 3D Object Detection (CVPR 2021)

GrooMeD-NMS: Grouped Mathematically Differentiable NMS for Monocular 3D Object Detection GrooMeD-NMS: Grouped Mathematically Differentiable NMS for Mo

Abhinav Kumar 76 Jan 02, 2023
A study project using the AA-RMVSNet to reconstruct buildings from multiple images

3d-building-reconstruction This is part of a study project using the AA-RMVSNet to reconstruct buildings from multiple images. Introduction It is exci

17 Oct 17, 2022
PyTorch implementation of our ICCV 2021 paper, Interpretation of Emergent Communication in Heterogeneous Collaborative Embodied Agents.

PyTorch implementation of our ICCV 2021 paper, Interpretation of Emergent Communication in Heterogeneous Collaborative Embodied Agents.

Saim Wani 4 May 08, 2022
Compact Bilinear Pooling for PyTorch

Compact Bilinear Pooling for PyTorch. This repository has a pure Python implementation of Compact Bilinear Pooling and Count Sketch for PyTorch. This

Grégoire Payen de La Garanderie 234 Dec 07, 2022
Codes and Data Processing Files for our paper.

Code Scripts and Processing Files for EEG Sleep Staging Paper 1. Folder Tree ./src_preprocess (data preprocessing files for SHHS and Sleep EDF) sleepE

Chaoqi Yang 18 Dec 12, 2022
Official Implementation for Encoding in Style: a StyleGAN Encoder for Image-to-Image Translation

Encoding in Style: a StyleGAN Encoder for Image-to-Image Translation We present a generic image-to-image translation framework, pixel2style2pixel (pSp

2.8k Dec 30, 2022
Pytorch Code for "Medical Transformer: Gated Axial-Attention for Medical Image Segmentation"

Medical-Transformer Pytorch Code for the paper "Medical Transformer: Gated Axial-Attention for Medical Image Segmentation" About this repo: This repo

Jeya Maria Jose 615 Dec 25, 2022
Implementation of Uformer, Attention-based Unet, in Pytorch

Uformer - Pytorch Implementation of Uformer, Attention-based Unet, in Pytorch. It will only offer the concat-cross-skip connection. This repository wi

Phil Wang 72 Dec 19, 2022
Python package to add text to images, textures and different backgrounds

nider Python package for text images generation and watermarking Free software: MIT license Documentation: https://nider.readthedocs.io. nider is an a

Vladyslav Ovchynnykov 131 Dec 30, 2022
CrossNorm and SelfNorm for Generalization under Distribution Shifts (ICCV 2021)

CrossNorm (CN) and SelfNorm (SN) (Accepted at ICCV 2021) This is the official PyTorch implementation of our CNSN paper, in which we propose CrossNorm

100 Dec 28, 2022
SMPL-X: A new joint 3D model of the human body, face and hands together

SMPL-X: A new joint 3D model of the human body, face and hands together [Paper Page] [Paper] [Supp. Mat.] Table of Contents License Description News I

Vassilis Choutas 1k Jan 09, 2023
CTRMs: Learning to Construct Cooperative Timed Roadmaps for Multi-agent Path Planning in Continuous Spaces

CTRMs: Learning to Construct Cooperative Timed Roadmaps for Multi-agent Path Planning in Continuous Spaces This is a repository for the following pape

17 Oct 13, 2022
使用yolov5训练自己数据集(详细过程)并通过flask部署

使用yolov5训练自己的数据集(详细过程)并通过flask部署 依赖库 torch torchvision numpy opencv-python lxml tqdm flask pillow tensorboard matplotlib pycocotools Windows,请使用 pycoc

HB.com 19 Dec 28, 2022
SAAVN - Sound Adversarial Audio-Visual Navigation,ICLR2022 (In PyTorch)

SAAVN SAAVN Code release for paper "Sound Adversarial Audio-Visual Navigation,IC

YinfengYu 10 Aug 30, 2022
Aggragrating Nested Transformer Official Jax Implementation

NesT is a simple method, which aggragrates nested local transformers on image blocks. The idea makes vision transformers attain better accuracy, data efficiency, and convergence on the ImageNet bench

Google Research 169 Dec 20, 2022
A Light CNN for Deep Face Representation with Noisy Labels

A Light CNN for Deep Face Representation with Noisy Labels Citation If you use our models, please cite the following paper: @article{wulight, title=

Alfred Xiang Wu 715 Nov 05, 2022
Code release for DS-NeRF (Depth-supervised Neural Radiance Fields)

Depth-supervised NeRF: Fewer Views and Faster Training for Free Project | Paper | YouTube Pytorch implementation of our method for learning neural rad

524 Jan 08, 2023
Enhancing Knowledge Tracing via Adversarial Training

Enhancing Knowledge Tracing via Adversarial Training This repository contains source code for the paper "Enhancing Knowledge Tracing via Adversarial T

Xiaopeng Guo 14 Oct 24, 2022
git《USD-Seg:Learning Universal Shape Dictionary for Realtime Instance Segmentation》(2020) GitHub: [fig2]

USD-Seg This project is an implement of paper USD-Seg:Learning Universal Shape Dictionary for Realtime Instance Segmentation, based on FCOS detector f

Ruolin Ye 80 Nov 28, 2022