Implementation of Rotary Embeddings, from the RoFormer paper, in PyTorch

Overview

Rotary Embeddings - PyTorch

A standalone library for adding rotary embeddings to transformers in PyTorch, following their success as a relative positional encoding. Specifically, it makes it easy and efficient to rotate information into any axis of a tensor, whether the rotations are fixed positional or learned. This library gives you state-of-the-art results for positional embedding, at little cost.

My gut also tells me there is something more to rotations that can be exploited in artificial neural networks.
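
For intuition, here is a minimal sketch of the underlying rotation (a hypothetical illustration, not this library's internal code): consecutive channel pairs are rotated by angles proportional to position, so query-key dot products end up depending only on relative position.

import torch

def rotate_pairs(x, angles):
    # rotate consecutive channel pairs (x1, x2) by the given angles
    x1, x2 = x[..., 0::2], x[..., 1::2]
    cos, sin = angles.cos(), angles.sin()
    return torch.stack((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim = -1).flatten(-2)

seq_len, dim = 1024, 64
pos = torch.arange(seq_len).float()
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) # RoFormer's 1 / 10000^(2i/d)
angles = torch.einsum('i, j -> i j', pos, inv_freq) # (seq len, dim // 2)

x = torch.randn(1, seq_len, dim)
x = rotate_pairs(x, angles) # angles broadcast over the batch dimension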

Install

$ pip install rotary-embedding-torch

Usage

import torch
from rotary_embedding_torch import apply_rotary_emb, RotaryEmbedding

# instantiate the positional embedding in your transformer and pass to all your attention layers

pos_emb = RotaryEmbedding(dim = 32)

# generate the rotations

freqs = pos_emb(torch.arange(1024), cache_key = 1024) # cache with a key equal to the sequence length, so the frequencies need not be recomputed

# mock queries and keys

q = torch.randn(1, 1024, 64) # queries - (batch, seq len, dimension of head)
k = torch.randn(1, 1024, 64) # keys

# apply the rotations to your queries and keys after the heads have been split out, but prior to the dot product and subsequent softmax (attention)

freqs = freqs[None, ...] # unsqueeze for batch dimension
q = apply_rotary_emb(freqs, q)
k = apply_rotary_emb(freqs, k)

# then do your attention with your queries (q) and keys (k)
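
# for example, a minimal single-head scaled dot-product attention
# (a hypothetical sketch, not part of this library)

v = torch.randn(1, 1024, 64) # values

sim = torch.einsum('b i d, b j d -> b i j', q, k) * (64 ** -0.5) # scaled dot product
attn = sim.softmax(dim = -1)
out = torch.einsum('b i j, b j d -> b i d', attn, v) # (batch, seq len, dimension of head)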

If you do all the steps above correctly, you should see a dramatic improvement during training.

Axial Rotary Embeddings

For easy use of 2D axial relative positional embeddings, e.g. in vision transformers

import torch
from rotary_embedding_torch import apply_rotary_emb, RotaryEmbedding, broadcat

pos_emb = RotaryEmbedding(
    dim = 32,
    freqs_for = 'pixel'
)

# queries and keys that the frequencies will be rotated into

q = torch.randn(1, 256, 256, 64)
k = torch.randn(1, 256, 256, 64)

# get frequencies for each axis
# -1 to 1 has been shown to be a good choice for images and audio

freqs_h = pos_emb(torch.linspace(-1, 1, steps = 256), cache_key = 256)
freqs_w = pos_emb(torch.linspace(-1, 1, steps = 256), cache_key = 256)

# concatenate the frequencies for each axis
# the broadcat function makes this easy without a bunch of expands

freqs = broadcat((freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim = -1)
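# freqs should now have shape (1, 256, 256, 64): 32 frequency values per axis, broadcast across the 256 x 256 grid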

# rotate in frequencies

q = apply_rotary_emb(freqs, q)
k = apply_rotary_emb(freqs, k)
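
Note that the rotated queries and keys above keep their two spatial axes; to run standard attention, one would typically flatten those into a single sequence axis first (a sketch, assuming the shapes above):

q = q.reshape(1, 256 * 256, 64) # (batch, height x width, dimension of head)
k = k.reshape(1, 256 * 256, 64)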

Learned Rotations

For injecting learned rotations into a network. Experiments pending.

Update: doesn't seem to do anything -_-, will keep trying...

import torch
from torch import nn
from rotary_embedding_torch import apply_learned_rotations

x = torch.randn(1, 1024, 512)

# you can only rotate in (dim // 2) values
# e.g. for 512, you can only rotate in 256 values

# say you have two sets of learned rotations of 128 values each

rots1 = nn.Linear(512, 128)(x)
rots2 = nn.Linear(512, 128)(x)

# first rotate in 256 dims (128 rotations x 2 dims each)

x = apply_learned_rotations(rots1, x, start_index = 0)

# then start at index 256 and rotate in the remaining 256 dims (128 x 2)

x = apply_learned_rotations(rots2, x, start_index = 256)

# you could also concat the rotations together and pass it in all at once

rots = torch.cat((rots1, rots2), dim = -1)

x = apply_learned_rotations(rots, x)
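
Conceptually (a sketch of the idea, not the library's actual internals), apply_learned_rotations performs the same pairwise rotation as in the overview sketch, only with per-token angles predicted by the network instead of derived from position:

# hypothetical equivalent of the combined call above, reusing rotate_pairs from the overview sketch
angles = torch.cat((rots1, rots2), dim = -1) # (1, 1024, 256) learned angles, one per channel pair
x = rotate_pairs(x, angles)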

Citations

@misc{su2021roformer,
    title   = {RoFormer: Enhanced Transformer with Rotary Position Embedding}, 
    author  = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
    year    = {2021},
    eprint  = {2104.09864},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}

Comments
  • Custom position offset when rotating queries or keys

    This library seems to assume that queries and keys are left-aligned position-wise e.g.

    q = [p_0, p_1, p_2]
    k = [p_0, p_1, p_2, p_3, p_4]
    

    where p_i are corresponding positions. This is enforced by always starting the sequence of positions from 0 with torch.arange(seq_len) here. Applications like Perceiver AR, however, require a position-wise right-alignment, e.g.

    q =           [p_2, p_3, p_4]
    k = [p_0, p_1, p_2, p_3, p_4]
    

    This pull request allows specifying a start position for queries and/or keys to enable alignments other than left-alignment. For example

    import torch
    from rotary_embedding_torch.rotary_embedding_torch import RotaryEmbedding
    
    rot = RotaryEmbedding(dim=32)
    
    q = torch.ones(1, 8, 4, 32)
    k = torch.ones(1, 8, 6, 32)
    
    q = q / torch.norm(q, dim=-1, keepdim=True)
    k = k / torch.norm(k, dim=-1, keepdim=True)
    
    q_rot = rot.rotate_queries_or_keys(q, start_pos=k.shape[2] - q.shape[2])
    k_rot = rot.rotate_queries_or_keys(k)
    
    attn = torch.einsum("b h i c, b h j c -> b h i j", q_rot, k_rot)
    print(attn[0, 0])
    

    prints the following relative position embedding

    tensor([[0.8581, 0.9571, 1.0000, 0.9571, 0.8581, 0.7670],
            [0.7670, 0.8581, 0.9571, 1.0000, 0.9571, 0.8581],
            [0.7288, 0.7670, 0.8581, 0.9571, 1.0000, 0.9571],
            [0.7361, 0.7288, 0.7670, 0.8581, 0.9571, 1.0000]])
    

    (diagonal of 1s right-aligned) whereas the default behavior

    ...
    
    q_rot = rot.rotate_queries_or_keys(q)
    k_rot = rot.rotate_queries_or_keys(k)
    
    attn = torch.einsum("b h i c, b h j c -> b h i j", q_rot, k_rot)
    print(attn[0, 0])
    

    would print

    tensor([[1.0000, 0.9571, 0.8581, 0.7670, 0.7288, 0.7361],
            [0.9571, 1.0000, 0.9571, 0.8581, 0.7670, 0.7288],
            [0.8581, 0.9571, 1.0000, 0.9571, 0.8581, 0.7670],
            [0.7670, 0.8581, 0.9571, 1.0000, 0.9571, 0.8581]])
    

    (diagonal of 1s left-aligned).

    opened by krasserm
  • about axial rotary embeddings

    Hi, thank you for sharing this code with us. However, I was confused by the axial rotary embeddings in the rotary_embedding_torch.py file:

        elif freqs_for == 'pixel':
            freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi

    Where does this formula come from? What is the max_freq parameter? Why are the freqs not 1 / (10000^(2i/d))?

    Thank you again.

    opened by raindrop313