Rax is a Learning-to-Rank library written in JAX

Related tags

Deep Learningrax
Overview

🦖 Rax: Composable Learning to Rank using JAX

Rax is a Learning-to-Rank library written in JAX. Rax provides off-the-shelf implementations of ranking losses and metrics to be used with JAX. It provides the following functionality:

  • Ranking losses (rax.*_loss): rax.softmax_loss, rax.pairwise_logistic_loss, ...
  • Ranking metrics (rax.*_metric): rax.mrr_metric, rax.ndcg_metric, ...
  • Transformations (rax.*_t12n): rax.approx_t12n, rax.gumbel_t12n, ...

Ranking

A ranking problem is different from traditional classification/regression problems in that its objective is to optimize for the correctness of the relative order of a list of examples (e.g., documents) for a given context (e.g., a query). Rax provides support for ranking problems within the JAX ecosystem. It can be used in, but is not limited to, the following applications:

  • Search: ranking a list of documents with respect to a query.
  • Recommendation: ranking a list of items given a user as context.
  • Question Answering: finding the best answer from a list of candidates.
  • Dialogue System: finding the best response from a list of responses.

Synopsis

In a nutshell, given the scores and labels for a list of items, Rax can compute various ranking losses and metrics:

import jax.numpy as jnp
import rax

scores = jnp.asarray([2.2, -1.3, 5.4])  # output of a model.
labels = jnp.asarray([1., 0., 0.])      # indicates doc 1 is relevant.

rax.ndcg_metric(scores, labels)         # computes a ranking metric.
rax.pairwise_hinge_loss(scores, labels) # computes a ranking loss.

All of the Rax losses and metrics are purely functional and compose well with standard JAX transformations. Additionally, Rax provides ranking-specific transformations so you can build new ranking losses. An example is rax.approx_t12n, which can be used to transform any (non-differentiable) ranking metric into a differentiable loss. For example:

loss_fn = rax.approx_t12n(rax.ndcg_metric)
loss_fn(scores, labels)            # differentiable approx ndcg loss.
jax.grad(loss_fn)(scores, labels)  # computes gradients w.r.t. scores.

Examples

See the examples/ directory for complete examples on how to use Rax.

You might also like...
[ICLR 2021] Rank the Episodes: A Simple Approach for Exploration in Procedurally-Generated Environments.
[ICLR 2021] Rank the Episodes: A Simple Approach for Exploration in Procedurally-Generated Environments.

[ICLR 2021] RAPID: A Simple Approach for Exploration in Reinforcement Learning This is the Tensorflow implementation of ICLR 2021 paper Rank the Episo

Rank 1st in the public leaderboard of ScanRefer (2021-03-18)
Rank 1st in the public leaderboard of ScanRefer (2021-03-18)

InstanceRefer InstanceRefer: Cooperative Holistic Understanding for Visual Grounding on Point Clouds through Instance Multi-level Contextual Referring

Code for
Code for "LoRA: Low-Rank Adaptation of Large Language Models"

LoRA: Low-Rank Adaptation of Large Language Models This repo contains the implementation of LoRA in GPT-2 and steps to replicate the results in our re

Official PyTorch Implementation of Rank & Sort Loss [ICCV2021]
Official PyTorch Implementation of Rank & Sort Loss [ICCV2021]

Rank & Sort Loss for Object Detection and Instance Segmentation The official implementation of Rank & Sort Loss. Our implementation is based on mmdete

This is the pytorch implementation for the paper: Generalizable Mixed-Precision Quantization via Attribution Rank Preservation, which is accepted to ICCV2021.

GMPQ: Generalizable Mixed-Precision Quantization via Attribution Rank Preservation This is the pytorch implementation for the paper: Generalizable Mix

 COD-Rank-Localize-and-Segment (CVPR2021)
COD-Rank-Localize-and-Segment (CVPR2021)

COD-Rank-Localize-and-Segment (CVPR2021) Simultaneously Localize, Segment and Rank the Camouflaged Objects Full camouflage fixation training dataset i

Gradient-free global optimization algorithm for multidimensional functions based on the low rank tensor train format

ttopt Description Gradient-free global optimization algorithm for multidimensional functions based on the low rank tensor train (TT) format and maximu

ViViT: Curvature access through the generalized Gauss-Newton's low-rank structure

ViViT is a collection of numerical tricks to efficiently access curvature from the generalized Gauss-Newton (GGN) matrix based on its low-rank structure. Provided functionality includes computing

This is the solution for 2nd rank in Kaggle competition: Feedback Prize - Evaluating Student Writing.

Feedback Prize - Evaluating Student Writing This is the solution for 2nd rank in Kaggle competition: Feedback Prize - Evaluating Student Writing. The

Comments
  • Add Kendall's tau

    Add Kendall's tau

    Kendall's tau (https://en.wikipedia.org/wiki/Kendall_rank_correlation_coefficient) is a ranking metric which we could add to Rax. We can reference the SciPy version of Kendall's tau (https://github.com/scipy/scipy/blob/v1.9.0/scipy/stats/_stats_py.py#L5015-L5222) as inspiration or comparison for the Rax implementation.

    Expected usage:

    scores = jnp.array([0., 1., 2.])
    labels = jnp.array([1., 2., 0.])
    rax.kendalltau_metric(scores, labels)  # should produce Kendall's tau-b.
    
    enhancement 
    opened by rjagerman 0
Releases(v0.2.0)
Owner
Google
Google ❤️ Open Source
Google
Tree-based Search Graph for Approximate Nearest Neighbor Search

TBSG: Tree-based Search Graph for Approximate Nearest Neighbor Search. TBSG is a graph-based algorithm for ANNS based on Cover Tree, which is also an

Fanxbin 2 Dec 27, 2022
Si Adek Keras is software VR dangerous object detection.

Si Adek Python Keras Sistem Informasi Deteksi Benda Berbahaya Keras Python. Version 1.0 Developed by Ananda Rauf Maududi. Developed date: 24 November

Ananda Rauf 1 Dec 21, 2021
ViewFormer: NeRF-free Neural Rendering from Few Images Using Transformers

ViewFormer: NeRF-free Neural Rendering from Few Images Using Transformers Official implementation of ViewFormer. ViewFormer is a NeRF-free neural rend

Jonáš Kulhánek 169 Dec 30, 2022
Cowsay - A rewrite of cowsay in python

Python Cowsay A rewrite of cowsay in python. Allows for parsing of existing .cow

James Ansley 3 Jun 27, 2022
[NeurIPS 2021] “Improving Contrastive Learning on Imbalanced Data via Open-World Sampling”,

Improving Contrastive Learning on Imbalanced Data via Open-World Sampling Introduction Contrastive learning approaches have achieved great success in

VITA 24 Dec 17, 2022
Editing a classifier by rewriting its prediction rules

This repository contains the code and data for our paper: Editing a classifier by rewriting its prediction rules Shibani Santurkar*, Dimitris Tsipras*

Madry Lab 86 Dec 27, 2022
Project dự đoán giá cổ phiếu bằng thuật toán LSTM gồm: code train và code demo

Web predicts stock prices using Long - Short Term Memory algorithm Give me some start please!!! User interface image: Choose: DayBegin, DayEnd, Stock

Vo Thuong Truong Nhon 8 Nov 11, 2022
Cross-Task Consistency Learning Framework for Multi-Task Learning

Cross-Task Consistency Learning Framework for Multi-Task Learning Tested on numpy(v1.19.1) opencv-python(v4.4.0.42) torch(v1.7.0) torchvision(v0.8.0)

Aki Nakano 2 Jan 08, 2022
Code repository for our paper "Learning to Generate Scene Graph from Natural Language Supervision" in ICCV 2021

Scene Graph Generation from Natural Language Supervision This repository includes the Pytorch code for our paper "Learning to Generate Scene Graph fro

Yiwu Zhong 64 Dec 24, 2022
DenseCLIP: Language-Guided Dense Prediction with Context-Aware Prompting

DenseCLIP: Language-Guided Dense Prediction with Context-Aware Prompting Created by Yongming Rao*, Wenliang Zhao*, Guangyi Chen, Yansong Tang, Zheng Z

Yongming Rao 322 Dec 31, 2022
ICCV2021: Code for 'Spatial Uncertainty-Aware Semi-Supervised Crowd Counting'

ICCV2021: Code for 'Spatial Uncertainty-Aware Semi-Supervised Crowd Counting'

Yanda Meng 14 May 13, 2022
SporeAgent: Reinforced Scene-level Plausibility for Object Pose Refinement

SporeAgent: Reinforced Scene-level Plausibility for Object Pose Refinement This repository implements the approach described in SporeAgent: Reinforced

Dominik Bauer 5 Jan 02, 2023
MoViNets PyTorch implementation: Mobile Video Networks for Efficient Video Recognition;

MoViNet-pytorch Pytorch unofficial implementation of MoViNets: Mobile Video Networks for Efficient Video Recognition. Authors: Dan Kondratyuk, Liangzh

189 Dec 20, 2022
PyTorch implementation of Weak-shot Fine-grained Classification via Similarity Transfer

SimTrans-Weak-Shot-Classification This repository contains the official PyTorch implementation of the following paper: Weak-shot Fine-grained Classifi

BCMI 60 Dec 02, 2022
This repository contains an implementation of ConvMixer for the ICLR 2022 submission "Patches Are All You Need?".

Patches Are All You Need? 🤷 This repository contains an implementation of ConvMixer for the ICLR 2022 submission "Patches Are All You Need?". Code ov

ICLR 2022 Author 934 Dec 30, 2022
QQ Browser 2021 AI Algorithm Competition Track 1 1st Place Program

QQ Browser 2021 AI Algorithm Competition Track 1 1st Place Program

249 Jan 03, 2023
This repository contains code accompanying the paper "An End-to-End Chinese Text Normalization Model based on Rule-Guided Flat-Lattice Transformer"

FlatTN This repository contains code accompanying the paper "An End-to-End Chinese Text Normalization Model based on Rule-Guided Flat-Lattice Transfor

THUHCSI 74 Nov 28, 2022
[CVPR 2022] Semi-Supervised Semantic Segmentation Using Unreliable Pseudo-Labels

Using Unreliable Pseudo Labels Official PyTorch implementation of Semi-Supervised Semantic Segmentation Using Unreliable Pseudo Labels, CVPR 2022. Ple

Haochen Wang 268 Dec 24, 2022
[NeurIPS 2021] Shape from Blur: Recovering Textured 3D Shape and Motion of Fast Moving Objects

[NeurIPS 2021] Shape from Blur: Recovering Textured 3D Shape and Motion of Fast Moving Objects YouTube | arXiv Prerequisites Kaolin is available here:

Denys Rozumnyi 107 Dec 26, 2022
DCGAN LSGAN WGAN-GP DRAGAN PyTorch

Recommendation Our GAN based work for facial attribute editing - AttGAN. News 8 April 2019: We re-implement these GANs by Tensorflow 2! The old versio

Zhenliang He 408 Nov 30, 2022