Sinkhorn Transformer - Practical implementation of Sparse Sinkhorn Attention

Overview

Sinkhorn Transformer


This is a reproduction of the work outlined in Sparse Sinkhorn Attention, with additional enhancements.

It includes a parameterized sorting network, using sinkhorn normalization to sample a permutation matrix that matches the most relevant buckets of keys to the buckets of queries.
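
As a rough illustration of the normalization step, the sketch below alternately normalizes the rows and columns of a small score matrix in log space, which drives it toward a doubly stochastic matrix (a relaxed permutation). It is plain Python rather than the library's PyTorch code, it leaves out the Gumbel noise and temperature, and the names `logsumexp` and `sinkhorn_normalize` as well as the example scores are assumptions made for the example.

```python
import math

def logsumexp(values):
    # Numerically stable log(sum(exp(v))).
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def sinkhorn_normalize(log_scores, n_iters=7):
    """Alternate row and column normalization in log space (hypothetical helper)."""
    r = [list(row) for row in log_scores]
    n_rows, n_cols = len(r), len(r[0])
    for _ in range(n_iters):
        # Row step: each row sums to 1 after exponentiation.
        for i in range(n_rows):
            lse = logsumexp(r[i])
            r[i] = [v - lse for v in r[i]]
        # Column step: each column sums to 1 after exponentiation.
        for j in range(n_cols):
            lse = logsumexp([r[i][j] for i in range(n_rows)])
            for i in range(n_rows):
                r[i][j] -= lse
    return [[math.exp(v) for v in row] for row in r]

# Made-up bucket-to-bucket scores; more iterations than usual so the
# result is visibly close to doubly stochastic.
scores = [[2.0, 0.1, 0.3],
          [0.2, 1.5, 0.1],
          [0.1, 0.4, 3.0]]
matrix = sinkhorn_normalize(scores, n_iters=50)
for row in matrix:
    print([round(p, 3) for p in row])
```

Working in log space avoids overflow when the raw scores are large, which is why the row and column steps subtract a log-sum-exp instead of dividing by a sum.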

This work also brings in reversible networks and feedforward chunking (concepts introduced in Reformer) for further memory savings.
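
The memory saving from reversibility comes from being able to recompute a block's inputs from its outputs, so activations need not be stored for the backward pass. A minimal sketch of that idea on plain floats, where `f` and `g` are hypothetical stand-ins for the attention and feedforward sublayers:

```python
def f(x):
    # Stand-in for the attention sublayer (any deterministic function works).
    return 0.5 * x + 1.0

def g(x):
    # Stand-in for the feedforward sublayer.
    return x * x

def reversible_forward(x1, x2):
    # Two coupled residual updates on the two halves of the input.
    y1 = x1 + f(x2)
    y2 = x2 + g(y1)
    return y1, y2

def reversible_inverse(y1, y2):
    # Undo the two residual updates in reverse order.
    x2 = y2 - g(y1)
    x1 = y1 - f(x2)
    return x1, x2

y1, y2 = reversible_forward(3.0, 4.0)
x1, x2 = reversible_inverse(y1, y2)
print(x1, x2)
```

The inverse only needs `f` and `g` to be deterministic, which is why a real implementation has to replay the same dropout randomness when it recomputes the inputs.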

Open In Colab 204k tokens (demonstration purposes)

Install

$ pip install sinkhorn_transformer

Use

A Sinkhorn Transformer based language model

import torch
from sinkhorn_transformer import SinkhornTransformerLM

model = SinkhornTransformerLM(
    num_tokens = 20000,
    dim = 1024,
    heads = 8,
    depth = 12,
    max_seq_len = 8192,
    bucket_size = 128,        # size of the buckets
    causal = False,           # auto-regressive or not
    n_sortcut = 2,            # use sortcut to reduce memory complexity to linear
    n_top_buckets = 2,        # sort specified number of key/value buckets to one query bucket. paper is at 1, defaults to 2
    ff_chunks = 10,           # feedforward chunking, from Reformer paper
    reversible = True,        # make network reversible, from Reformer paper
    emb_dropout = 0.1,        # embedding dropout
    ff_dropout = 0.1,         # feedforward dropout
    attn_dropout = 0.1,       # post attention dropout
    attn_layer_dropout = 0.1, # post attention layer dropout
    layer_dropout = 0.1,      # add layer dropout, from 'Reducing Transformer Depth on Demand' paper
    weight_tie = True,        # tie layer parameters, from Albert paper
    emb_dim = 128,            # embedding factorization, from Albert paper
    dim_head = 64,            # be able to fix the dimension of each head, making it independent of the embedding dimension and the number of heads
    ff_glu = True,            # use GLU in feedforward, from paper 'GLU Variants Improve Transformer'
    n_local_attn_heads = 2,   # replace N heads with local attention, suggested to work well from Routing Transformer paper
    pkm_layers = (4,7),       # specify layers to use product key memory. paper shows 1 or 2 modules near the middle of the transformer is best
    pkm_num_keys = 128,       # defaults to 128, but can be increased to 256 or 512 as memory allows
)

x = torch.randint(0, 20000, (1, 2048))
model(x) # (1, 2048, 20000)

A plain Sinkhorn Transformer, layers of sinkhorn attention

import torch
from sinkhorn_transformer import SinkhornTransformer

model = SinkhornTransformer(
    dim = 1024,
    heads = 8,
    depth = 12,
    bucket_size = 128
)

x = torch.randn(1, 2048, 1024)
model(x) # (1, 2048, 1024)

Sinkhorn Encoder / Decoder Transformer

import torch
from sinkhorn_transformer import SinkhornTransformerLM

DE_SEQ_LEN = 4096
EN_SEQ_LEN = 4096

enc = SinkhornTransformerLM(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    heads = 8,
    bucket_size = 128,
    max_seq_len = DE_SEQ_LEN,
    reversible = True,
    return_embeddings = True
).cuda()

dec = SinkhornTransformerLM(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    causal = True,
    bucket_size = 128,
    max_seq_len = EN_SEQ_LEN,
    receives_context = True,
    context_bucket_size = 128,  # context key / values can be bucketed differently
    reversible = True
).cuda()

x = torch.randint(0, 20000, (1, DE_SEQ_LEN)).cuda()
y = torch.randint(0, 20000, (1, EN_SEQ_LEN)).cuda()

x_mask = torch.ones_like(x).bool().cuda()
y_mask = torch.ones_like(y).bool().cuda()

context = enc(x, input_mask=x_mask)
dec(y, context=context, input_mask=y_mask, context_mask=x_mask) # (1, 4096, 20000)

Autopadder

By default the model will complain if given an input whose length is not a multiple of the bucket size. To avoid repeating the same padding calculations each time, you can use the helper Autopadder class. It also takes care of the input_mask for you, if given. Contextual key/values and masks are supported as well.

import torch
from sinkhorn_transformer import SinkhornTransformerLM
from sinkhorn_transformer import Autopadder

model = SinkhornTransformerLM(
    num_tokens = 20000,
    dim = 1024,
    heads = 8,
    depth = 12,
    max_seq_len = 2048,
    bucket_size = 128,
    causal = True
)

model = Autopadder(model, pad_left=True) # autopadder will fetch the bucket size and autopad input

x = torch.randint(0, 20000, (1, 1117)) # odd sequence length
model(x) # (1, 1117, 20000)
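
The padding arithmetic that Autopadder saves you from writing can be sketched in plain Python. The helper name `pad_to_bucket_multiple` and the pad token id `0` are assumptions for illustration, not the library's internals:

```python
def pad_to_bucket_multiple(tokens, bucket_size, pad_value=0, pad_left=True):
    """Pad a token list to the next multiple of bucket_size and build its mask."""
    remainder = len(tokens) % bucket_size
    pad_len = 0 if remainder == 0 else bucket_size - remainder
    padding = [pad_value] * pad_len
    mask_pad = [False] * pad_len      # padded positions are masked out
    mask = [True] * len(tokens)       # real tokens are kept
    if pad_left:
        return padding + tokens, mask_pad + mask
    return tokens + padding, mask + mask_pad

tokens = list(range(1, 1118))         # 1117 tokens, as in the example above
padded, mask = pad_to_bucket_multiple(tokens, bucket_size=128)
print(len(padded), sum(mask))
```

With a bucket size of 128, the 1117 tokens are padded up to 1152 (nine buckets), and the mask still marks exactly 1117 real positions.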

Sinkhorn

This repository has diverged from the paper and now uses attention in place of the original sorting net + Gumbel Sinkhorn sampling. I have not found a noticeable difference in performance yet, and the new scheme lets me generalize the network to flexible sequence lengths. If you would like to try Sinkhorn, please use the following settings, which only work for non-causal networks.

import torch
from sinkhorn_transformer import SinkhornTransformerLM

model = SinkhornTransformerLM(
    num_tokens = 20000,
    dim = 1024,
    heads = 8,
    depth = 12,
    bucket_size = 128,
    max_seq_len = 8192,
    use_simple_sort_net = True, # turn off attention sort net
    sinkhorn_iter = 7,          # number of sinkhorn iterations - default is set at reported best in paper
    n_sortcut = 2,              # use sortcut to reduce complexity to linear time
    temperature = 0.75,         # gumbel temperature - default is set at reported best in paper
    non_permutative = False,    # allow buckets of keys to be sorted to queries more than once
)

x = torch.randint(0, 20000, (1, 8192))
model(x) # (1, 8192, 20000)

Product Key Memory

To see the benefits of using PKM, the learning rate of the values must be set higher than that of the rest of the parameters (1e-2 is recommended).

You can follow the instructions here to set it correctly: https://github.com/lucidrains/product-key-memory#learning-rates

Issues

Decoding and sequence lengths

Sinkhorn, when trained on fixed length sequences, seems to have trouble decoding sequences from scratch, mainly due to the fact that the sorting net has trouble generalizing when the buckets are partially filled with padding tokens.

Fortunately, I think I have found a simple solution. During training, for causal networks, randomly truncate the sequences and force the sorting net to generalize. I have provided a flag (randomly_truncate_sequence) for the AutoregressiveWrapper instance to make this easy.

import torch
from sinkhorn_transformer import SinkhornTransformerLM, AutoregressiveWrapper

model = SinkhornTransformerLM(
    num_tokens = 20000,
    dim = 1024,
    heads = 8,
    depth = 12,
    bucket_size = 75,
    max_seq_len = 8192,
    causal = True
)

model = AutoregressiveWrapper(model)

x = torch.randint(0, 20000, (1, 8192))
loss = model(x, return_loss = True, randomly_truncate_sequence = True) # scalar training loss
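
The idea behind random truncation, sketched in plain Python. How the wrapper actually picks the cut point is not described here, so the uniform choice and the names `randomly_truncate` and `min_len` are assumptions:

```python
import random

def randomly_truncate(batch, min_len=1, rng=random):
    """Cut every sequence in the batch to one shared, randomly chosen length."""
    max_len = len(batch[0])
    new_len = rng.randint(min_len, max_len)  # inclusive on both ends
    return [seq[:new_len] for seq in batch]

rng = random.Random(0)                       # seeded only to make the sketch repeatable
batch = [list(range(16)), list(range(100, 116))]
truncated = randomly_truncate(batch, min_len=4, rng=rng)
print(len(truncated[0]))
```

Because the cut point changes from step to step, the last bucket is filled to a different degree each time, which is the situation the sorting net has to cope with at decoding time.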

I am open to suggestions if someone has found a better solution.

Causal sorting net

There is a potential problem with the causal sorting network: the decision of which past key/value buckets get sorted to a given bucket depends only on that bucket's first token and not on the rest (due to the bucketing scheme and the need to prevent leakage from future to past).

I have attempted to alleviate this problem by rotating half the heads to the left by bucket size - 1, thereby promoting the last token to be first. This is also why the AutoregressiveWrapper defaults to left padding during training: it makes sure that the last token in the sequence always has a say in what to retrieve.
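
To see what the rotation does to bucket membership, here is a small sketch on token positions. It only illustrates the shift itself; the helper names are hypothetical, and how the library handles the positions that wrap around is not covered:

```python
def rotate_left(seq, shift):
    # Move the first `shift` items to the end of the sequence.
    shift %= len(seq)
    return seq[shift:] + seq[:shift]

def to_buckets(seq, bucket_size):
    # Split a flat sequence into consecutive buckets.
    return [seq[i:i + bucket_size] for i in range(0, len(seq), bucket_size)]

bucket_size = 4
positions = list(range(8))

plain = to_buckets(positions, bucket_size)
rotated = to_buckets(rotate_left(positions, bucket_size - 1), bucket_size)

print(plain)    # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(rotated)  # [[3, 4, 5, 6], [7, 0, 1, 2]]
```

In the rotated heads, positions 3 and 7 (the last tokens of the original buckets) now sit first in a bucket, so they are the ones that drive the sorting decision.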

If anyone has found a cleaner solution, please let me know in the issues.

Alternatives

  1. Routing Transformer - https://github.com/lucidrains/routing-transformer
  2. Reformer - https://github.com/lucidrains/reformer-pytorch

Citations

@misc{tay2020sparse,
    title   = {Sparse Sinkhorn Attention},
    author  = {Yi Tay and Dara Bahri and Liu Yang and Donald Metzler and Da-Cheng Juan},
    year    = {2020},
    url     = {https://arxiv.org/abs/2002.11296}
}
@inproceedings{kitaev2020reformer,
    title       = {Reformer: The Efficient Transformer},
    author      = {Nikita Kitaev and Lukasz Kaiser and Anselm Levskaya},
    booktitle   = {International Conference on Learning Representations},
    year        = {2020},
    url         = {https://openreview.net/forum?id=rkgNKkHtvB}
}
@misc{lan2019albert,
    title       = {ALBERT: A Lite BERT for Self-supervised Learning of Language Representations},
    author      = {Zhenzhong Lan and Mingda Chen and Sebastian Goodman and Kevin Gimpel and Piyush Sharma and Radu Soricut},
    year        = {2019},
    url         = {https://arxiv.org/abs/1909.11942}
}
@misc{shazeer2020glu,
    title   = {GLU Variants Improve Transformer},
    author  = {Noam Shazeer},
    year    = {2020},
    url     = {https://arxiv.org/abs/2002.05202}
}
@misc{roy*2020efficient,
    title   = {Efficient Content-Based Sparse Attention with Routing Transformers},
    author  = {Aurko Roy* and Mohammad Taghi Saffar* and David Grangier and Ashish Vaswani},
    year    = {2020},
    url     = {https://openreview.net/forum?id=B1gjs6EtDr}
}
@inproceedings{fan2020reducing,
    title     ={Reducing Transformer Depth on Demand with Structured Dropout},
    author    ={Angela Fan and Edouard Grave and Armand Joulin},
    booktitle ={International Conference on Learning Representations},
    year      ={2020},
    url       ={https://openreview.net/forum?id=SylO2yStDr}
}
@misc{lample2019large,
    title   = {Large Memory Layers with Product Keys},
    author  = {Guillaume Lample and Alexandre Sablayrolles and Marc'Aurelio Ranzato and Ludovic Denoyer and Hervé Jégou},
    year    = {2019},
    eprint  = {1907.05242},
    archivePrefix = {arXiv}
}
@misc{bhojanapalli2020lowrank,
    title   = {Low-Rank Bottleneck in Multi-head Attention Models},
    author  = {Srinadh Bhojanapalli and Chulhee Yun and Ankit Singh Rawat and Sashank J. Reddi and Sanjiv Kumar},
    year    = {2020},
    eprint  = {2002.07028}
}

Comments
  • Training falling on version 0.0.14 and 0.0.15

    Hi, I am testing model training on new versions of the repo, and I have some trouble with 0.0.14 and 0.0.15. On 0.0.14, the model always returns nan on the forward pass; version 0.0.15 leads to a CUDA error:

    RuntimeError: CUDA error: CUBLAS_STATUS_INTERNAL_ERROR when calling `cublasSgemm( handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)`
    

    Full error listing:

    <ipython-input-7-1329da5363de> in forward(self, inputs, labels)
          7   def forward(self, inputs, labels=None):
          8     loss_mx = labels != -100
    ----> 9     output = self.model(inputs)
         10     output = output[loss_mx].view(-1, tokenizer.vocab_size)
         11     labels = labels[loss_mx].view(-1)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in forward(self, x, input_mask)
        376         x = self.to_token_emb(x)
        377         x = self.pos_emb(torch.arange(t, device=device)) + x
    --> 378         x = self.sinkhorn_transformer(x)
        379         return self.to_logits(x)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in forward(self, x, input_mask)
        359 
        360     def forward(self, x, input_mask = None):
    --> 361         return self.layers(x)
        362 
        363 class SinkhornTransformerLM(nn.Module):
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in forward(self, x, **kwargs)
        330     def forward(self, x, **kwargs):
        331         x = torch.cat([x, x], dim=-1)
    --> 332         x = self.layers(x, **kwargs)
        333         return torch.stack(x.chunk(2, dim=-1)).sum(dim=0)
        334 
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/reversible.py in forward(self, x, arg_route, **kwargs)
        128         block_kwargs = {'f_args': f_args, 'g_args': g_args}
        129 
    --> 130         return _ReversibleFunction.apply(x, blocks, block_kwargs)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/reversible.py in forward(ctx, x, blocks, kwargs)
         98         ctx.kwargs = kwargs
         99         for block in blocks:
    --> 100             x = block(x, **kwargs)
        101         ctx.y = x.detach()
        102         ctx.blocks = blocks
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/reversible.py in forward(self, x, f_args, g_args)
         51         with torch.no_grad():
         52             y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
    ---> 53             y2 = x2 + self.g(y1, record_rng=self.training, **g_args)
         54 
         55         return torch.cat([y1, y2], dim=2)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/reversible.py in forward(self, record_rng, set_rng, *args, **kwargs)
         25 
         26         if not set_rng:
    ---> 27             return self.net(*args, **kwargs)
         28 
         29         rng_devices = []
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in forward(self, x)
         91     def forward(self, x):
         92         chunks = x.chunk(self.chunks, dim = self.dim)
    ---> 93         return torch.cat([self.fn(c) for c in chunks], dim = self.dim)
         94 
         95 class FeedForward(nn.Module):
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in <listcomp>(.0)
         91     def forward(self, x):
         92         chunks = x.chunk(self.chunks, dim = self.dim)
    ---> 93         return torch.cat([self.fn(c) for c in chunks], dim = self.dim)
         94 
         95 class FeedForward(nn.Module):
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in forward(self, x, **kwargs)
        112     def forward(self, x, **kwargs):
        113         x = self.norm(x)
    --> 114         return self.fn(x, **kwargs)
        115 
        116 class SortNet(nn.Module):
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py in forward(self, x)
        103 
        104     def forward(self, x):
    --> 105         return self.net(x)
        106 
        107 class PreNorm(nn.Module):
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/container.py in forward(self, input)
         98     def forward(self, input):
         99         for module in self:
    --> 100             input = module(input)
        101         return input
        102 
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
        548             result = self._slow_forward(*input, **kwargs)
        549         else:
    --> 550             result = self.forward(*input, **kwargs)
        551         for hook in self._forward_hooks.values():
        552             hook_result = hook(self, input, result)
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/modules/linear.py in forward(self, input)
         85 
         86     def forward(self, input):
    ---> 87         return F.linear(input, self.weight, self.bias)
         88 
         89     def extra_repr(self):
    
    /opt/anaconda/envs/torch-nigtly-reformer/lib/python3.6/site-packages/torch/nn/functional.py in linear(input, weight, bias)
       1591         ret = torch.addmm(bias, input, weight.t())
       1592     else:
    -> 1593         output = input.matmul(weight.t())
       1594         if bias is not None:
       1595             output += bias
    
    RuntimeError: CUDA error: CUBLAS_STATUS_INTERNAL_ERROR when calling `cublasSgemm( handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)`
    

    Also, version 0.0.11 (and all other versions from 0.0.8 onward) works stably.

    opened by blizda 30
  • generation problem in a toy task

    Here is the full script for my toy task (x -> xx like "abc" to "abcabc")

    from sinkhorn_transformer import SinkhornTransformerLM
    from sinkhorn_transformer.autoregressive_wrapper import AutoregressiveWrapper
    
    import random
    import tqdm
    import gzip
    import numpy as np
    import torch
    import torch.optim as optim
    from torch import nn
    from torch.nn import functional as F
    from torch.utils.data import DataLoader, Dataset
    
    # constants
    
    NUM_BATCHES = int(1e5)
    BATCH_SIZE = 4
    GRADIENT_ACCUMULATE_EVERY = 4
    LEARNING_RATE = 1e-4
    VALIDATE_EVERY  = 100
    GENERATE_EVERY  = 100
    ENC_SEQ_LEN=16
    DEC_SEQ_LEN=40
    NUM_TOKENS = 256 + 2
    BUCKET_SIZE = 8
    
    # helpers
    
    def top_k(logits, thres = 0.9):
        k = int((1 - thres) * logits.shape[-1])
        val, ind = torch.topk(logits, k)
        probs = torch.full_like(logits, float('-inf'))
        probs.scatter_(1, ind, val)
        return probs
    
    
    def cycle():
        while True:
            source = torch.randint(2, 258, (BATCH_SIZE, ENC_SEQ_LEN)).long().cuda()
    
            target = torch.cat((source, source), 1)
            prefix = torch.ones((BATCH_SIZE, 1)).long().cuda()
            target = torch.cat((prefix, target), axis=1)
    
            x_mask = torch.ones(BATCH_SIZE, ENC_SEQ_LEN).bool().cuda()
            y_mask = torch.ones(BATCH_SIZE, target.shape[1]).bool().cuda()
    
    
            yield (source, target, x_mask, y_mask)
    
    # instantiate model
    
    class MySinkhornTransformer(nn.Module):
        def __init__(self, num_tokens, dim, depth, heads, bucket_size, enc_max_seq_len, dec_max_seq_len):
            super().__init__()
            
            self.pad_token = 0
            self.sos_token = 1
    
            self.enc = SinkhornTransformerLM(num_tokens=num_tokens, dim=dim, depth=depth, heads=heads, bucket_size=bucket_size, max_seq_len=enc_max_seq_len,
                                             reversible=True, return_embeddings=True)
            self.dec = SinkhornTransformerLM(num_tokens=num_tokens, dim=dim, depth=depth, heads=heads, causal=True, bucket_size=bucket_size, max_seq_len=dec_max_seq_len, 
                                             receives_context=True, context_bucket_size=bucket_size, reversible=True)
            self.dec = AutoregressiveWrapper(self.dec, pad_value=num_tokens-2)
        
        @torch.no_grad()
        def generate(self, x, x_mask):
            context = self.enc(x, input_mask=x_mask)
            start_tokens = (torch.ones((x.shape[0],1)) * self.sos_token).long().cuda()
    
            return self.dec.generate(start_tokens, 32, context=context, context_mask=x_mask)
    
        def forward(self, x, y, x_mask, y_mask, return_loss):
            context = self.enc(x, input_mask=x_mask)
            return self.dec(y, context=context, input_mask=y_mask, context_mask=x_mask, return_loss=True)
    
    
    model = MySinkhornTransformer(num_tokens=NUM_TOKENS, dim=512, depth=1, heads=1, bucket_size=BUCKET_SIZE, enc_max_seq_len=ENC_SEQ_LEN, dec_max_seq_len=DEC_SEQ_LEN)
    model.cuda()
    # optimizer
    
    optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    
    # training
    
    for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
        model.train()
    
        for __ in range(GRADIENT_ACCUMULATE_EVERY):
            source, target, x_mask, y_mask = next(cycle())
            loss = model(x=source, y=target, x_mask=x_mask, y_mask=y_mask, return_loss=True)
            loss.backward()
    
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optim.step()
        optim.zero_grad()
    
        if i % VALIDATE_EVERY == 0:
            model.eval()
            with torch.no_grad():
                source, target, x_mask, y_mask = next(cycle())
                loss = model(x=source, y=target, x_mask=x_mask, y_mask=y_mask, return_loss=True)
                print(f'validation loss: {loss.item()}')
    
        if i % GENERATE_EVERY == 0:
            model.eval()
            
            source, target, x_mask, y_mask = next(cycle())
            
            sample = model.generate(x=source, x_mask=x_mask)
            print("input:  ", source[0])
            print("model output:  ", sample[0])
    

    After a few steps the loss becomes practically zero. I checked the logits during training and they seem to be OK, but during the generation phase the model outputs this pattern: "x,x,x,x,x,y,y,y,y,y", like "aaaabbbb" instead of "abcdabcd". I was wondering what the underlying issue might be. Do you have any idea?

    opened by py4 18
  • Sortcut variant and bucket size

    From the paper it looks to me that, for the sortcut variant, the queries are allowed to attend to all the key buckets post truncation in the non-causal case. If this is correct, won't the output be the same no matter what the query bucket size is?

    Based on my understanding of the paper, the authors only report results for selecting the top 2 key/value blocks with bucket sizes 8, 16, and 32. They do not mention using different bucket sizes for queries and key/values, which I found in this repo.

    opened by jsmith1915 16
  • A noobish question about training...

    @lucidrains

    First of all, everything seems to be working now in Google Colab so thank you very much for fixing it.

    I have a quick question about training if you do not mind...

    I get very high loss results. Here is the example:

    training: 1%| | 509/100000 [1:55:22<387:23:57, 14.02s/it]training loss: 2.495532512664795

    Is this normal and I simply need to train more? Or does it mean that there is a problem somewhere?

    2.49 in 2 hours is way too much IMHO. I am pretty sure I am not doing it right, so your advice would be really appreciated.

    Thank you.

    P.S. I am running your wiki8 example in Google Colab Pro.

    opened by asigalov61 9
  • several question about implementation

    I am currently trying to implement Sparse Sinkhorn attention (non-causal, self-attention, pretraining for an MLM task) with TensorFlow, and I would appreciate it if you could answer several questions about your code.

    1. I did not quite understand from the paper what happens to queries and keys in SortCut when we cut off most of the blocks. In your code it seems you broadcast keys/queries like this: [number_blocks, block_len, size_per_head] -> [:n_sortcut, block_len, size_per_head] -> [1, n_sortcut * block_len, size_per_head] -> [number_blocks, n_sortcut * block_len, size_per_head]. I understand that PyTorch handles expand_dim automatically and memory is preserved, but I do not understand how the memory consumption stays linear: the resulting attention scores (dots) are still [number_blocks, block_len, n_sortcut * block_len], which is good but quadratic. Am I missing something?

    2. In the case of SortCut, what is the meaning of softmax(A)*Q? In full attention, outputs are vectors that are weighted averages of other vectors; in local block attention, outputs are weighted averages of other vectors in the block. How would you interpret the output of a SortCut attention layer?

    3. In your code you concatenate the regular keys and values with the permuted/sorted keys and values, but in the paper (part 3.2) it seems to be a simple addition. Why is it different from the paper? Actually, I can understand concatenation better than summation (because I am confused about which attention masks to use in that case: from the regular block or the permuted one).

    4. It seems you are using concatenated queries and keys as the input to the SortNet. Maybe I completely missed the point, but I did not find anything about this in the article: there, the input sequence is used for the SortNet instead of its projections. In fact, for some reason, I cannot make it converge if I use queries and keys (maybe for completely unrelated reasons). Have you tried comparing the input sequence vs. concat(q,k)?

    opened by w4-magnes 9
  • TypeError exception in AxialPositionalEncoding when using DataParallel

    Hello,

    I want to run SinkhornTransformerLM using multiple GPUs, so I'm wrapping the model into torch.nn.DataParallel. However, when I do this, I get an exception:

    Traceback (most recent call last):
      File "script.py", line 27, in <module>
        model(x)
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 155, in forward
        outputs = self.parallel_apply(replicas, inputs, kwargs)
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 165, in parallel_apply
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 85, in parallel_apply
        output.reraise()
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/_utils.py", line 395, in reraise
        raise self.exc_type(msg)
    TypeError: Caught TypeError in replica 0 on device 0.
    Original Traceback (most recent call last):
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 60, in _worker
        output = module(*input, **kwargs)
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/ubuntu/.local/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py", line 792, in forward
        x = self.axial_pos_emb(x) + x
      File "/home/ubuntu/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/ubuntu/.local/lib/python3.6/site-packages/sinkhorn_transformer/sinkhorn_transformer.py", line 243, in forward
        return pos_emb[:, :t]
    TypeError: 'int' object is not subscriptable
    

    Looking at the code, it would seem that self.weights does not get populated. To reproduce this error, I took the first example in README.md and changed

    model(x) # (1, 2048, 20000)
    

    to

    model = torch.nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count()))).to('cuda')
    model(x)
    
    opened by kl0211 8
  • Do I need to pad?

    Excuse me for the noob question, I have sequences of different lengths and I use:

    model = AutoregressiveWrapper(model, ignore_index=PAD_ID, pad_value=PAD_ID)
    

    Do I need to pad as in this example or not?

    def __getitem__(self, index):
        seq_tokens = self.examples[index]
        input_ids = torch.tensor(seq_tokens, dtype=torch.long)
        input_ids = F.pad(input_ids, (seq_len - len(input_ids), 0), value=self.pad_value)
        return input_ids
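
For reference, the `F.pad` call in the snippet above pads on the left, because the first number in the pad tuple is the amount added before the data. The following is a minimal sketch of that behavior only, and does not show what `AutoregressiveWrapper` requires. `PAD_ID`, `SEQ_LEN`, and `left_pad` are hypothetical names chosen for illustration.

```python
import torch
import torch.nn.functional as F

PAD_ID = 0   # hypothetical padding token id
SEQ_LEN = 8  # hypothetical fixed sequence length


def left_pad(tokens, seq_len=SEQ_LEN, pad_value=PAD_ID):
    """Left-pad a list of token ids to seq_len, mirroring the snippet above."""
    ids = torch.tensor(tokens, dtype=torch.long)
    # For a 1-D tensor the pad tuple is (left, right).
    return F.pad(ids, (seq_len - len(ids), 0), value=pad_value)


# torch.stack needs equal shapes, so sequences of different lengths
# have to be brought to a common length before they can be batched.
batch = torch.stack([left_pad(s) for s in ([5, 6, 7], [1, 2, 3, 4, 9])])
```

Every sequence in the batch ends up with the same length, with the padding tokens placed before the real tokens.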
    
    opened by timsoraro 8
  • Performance of the different attention variants in the repo

    Really great work!!

    I have a few questions.

    - Does the attention sort net work better than, or on par with, the simple sort net?
    - Does the attention sort net also work well for the sortcut variant?
    - How does the sortcut variant perform compared to vanilla sparse Sinkhorn attention?

    It would be very helpful if you could share some plots/numbers from your experiments comparing the performance of the different variants such as attention sort net, routing based attention etc.

    opened by jsmith1915 7
  • Some questions about dropout

    Hi again @lucidrains, I have some quick questions about dropout in the Sinkhorn Transformer. I was using my Linformer implementation (which, as you know, is based on this repo), but it was overfitting my dataset, so I wanted to ask whether some design choices here were intentional:

    1. In the original Transformer, dropout was performed after each sublayer, before the residual connection. I noticed that you only have this after the SinkhornSelfAttention class, but not after the FeedForward class. Is this intentional?
    2. Speaking of the FeedForward class, you insert dropout after the first linear layer. I couldn't find this anywhere in the literature. Were you able to find a reference for why this is effective? I put it into my implementation, and it seems to help, but I just don't know where this idea came from.
    3. On a similar note, do you know why the dots tensor in the self-attention classes is dropped out? Again, I put it in my Linformer and it seems to work, but I can't find a reference to this in the literature.
    4. Finally, the original Transformer also dropped out the input tokens, like so (from the SinkhornTransformerLM class):
        def forward(self, x, **kwargs):
            _, t, device = *x.shape, x.device
            assert t <= self.max_seq_len, f'sequence length {t} is greater than maximum sequence length {self.max_seq_len}'
    
            x = self.to_token_emb(x)
            x = self.axial_pos_emb(x) + x
            """ Dropout would go here"""
            x = self.sinkhorn_transformer(x, **kwargs)
            return self.to_logits(x)
    

    Should they be dropped out here as well?

    I have now updated my repo so that all 4 of these dropout possibilities exist. I'll let you know if this helps with overfitting.

    Thank you for your time!
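
The four dropout sites discussed in this issue can be sketched in a toy model. This is not the repository's code: `Block` and `TinyLM` are hypothetical classes, the block uses pre-norm residuals, and attention dropout is assumed to act on the softmax-normalized weights. The README's `emb_dropout`, `ff_dropout`, `attn_dropout`, and `attn_layer_dropout` arguments suggest similar sites, but the mapping shown here is an assumption.

```python
import torch
import torch.nn.functional as F
from torch import nn


class Block(nn.Module):
    """Toy transformer block marking dropout sites 1-3 from the list above."""

    def __init__(self, dim, heads=4, p=0.1):
        super().__init__()
        self.heads = heads
        self.norm1, self.norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.to_out = nn.Linear(dim, dim)
        self.ff_in, self.ff_out = nn.Linear(dim, dim * 4), nn.Linear(dim * 4, dim)
        self.sublayer_dropout = nn.Dropout(p)  # site 1: after each sublayer
        self.ff_dropout = nn.Dropout(p)        # site 2: after the first linear layer
        self.attn_dropout = nn.Dropout(p)      # site 3: on the attention weights

    def attention(self, x):
        b, t, d = x.shape
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        # (b, t, d) -> (b, heads, t, d // heads)
        q, k, v = (z.reshape(b, t, self.heads, -1).transpose(1, 2) for z in (q, k, v))
        dots = q @ k.transpose(-1, -2) * (d // self.heads) ** -0.5
        attn = self.attn_dropout(dots.softmax(dim=-1))
        return self.to_out((attn @ v).transpose(1, 2).reshape(b, t, d))

    def feed_forward(self, x):
        return self.ff_out(self.ff_dropout(F.gelu(self.ff_in(x))))

    def forward(self, x):
        x = x + self.sublayer_dropout(self.attention(self.norm1(x)))
        return x + self.sublayer_dropout(self.feed_forward(self.norm2(x)))


class TinyLM(nn.Module):
    """Toy language model marking dropout site 4 (on the embeddings)."""

    def __init__(self, num_tokens, dim, max_seq_len, p=0.1):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(max_seq_len, dim)
        self.emb_dropout = nn.Dropout(p)       # site 4: after adding positions
        self.block = Block(dim, p=p)
        self.to_logits = nn.Linear(dim, num_tokens)

    def forward(self, x):
        positions = torch.arange(x.shape[1], device=x.device)
        x = self.emb_dropout(self.token_emb(x) + self.pos_emb(positions))
        return self.to_logits(self.block(x))


model = TinyLM(num_tokens=100, dim=32, max_seq_len=16).eval()
tokens = torch.randint(0, 100, (2, 16))
logits = model(tokens)  # shape (2, 16, 100); dropout is inactive in eval mode
```

Giving each site its own `nn.Dropout` makes it possible to switch them on one at a time when checking which one affects overfitting.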

    opened by tatp22 6
  • Training crashes due to inplace operation

    Hi there, I am trying to train this model on my custom dataset. Training starts, but after many iterations it crashes with this error:

    RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [2, 128, 64]], which is output 0 of ViewBackward, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
    

    After enabling anomaly detection, here is the error:

    RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [2, 128, 64]], which is output 0 of ViewBackward, is at version 1; expected version 0 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
    

    The traceback is not useful, as it only points to the loss.backward() line and the torch.autograd module. Can you help me find where the issue might be?

    Thanks
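
For context, this class of error can be reproduced in a few lines. The sketch below is a toy case and not a diagnosis of the crash reported here: it only shows that editing a tensor in place after autograd has saved it for the backward pass triggers the same message.

```python
import torch

x = torch.ones(3, requires_grad=True)
y = x.exp()  # exp saves its output for the backward pass
y.add_(1)    # the in-place edit bumps y's version counter

try:
    y.sum().backward()
    error = None
except RuntimeError as exc:
    error = str(exc)  # "... has been modified by an inplace operation ..."

# The out-of-place version of the same computation backpropagates fine.
x2 = torch.ones(3, requires_grad=True)
y2 = x2.exp() + 1
y2.sum().backward()
```

Searching the model code for methods with a trailing underscore (such as `add_`) and for augmented or indexed assignment on tensors is one way to narrow down candidates.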

    opened by NaxAlpha 6
  • Positional Embedding

    Hi! I found out that SinkhornTransformerLM uses two positional encodings simultaneously:

    https://github.com/lucidrains/sinkhorn-transformer/blob/2b65e24085562a4e308251398007e2ca9b86d7cc/sinkhorn_transformer/sinkhorn_transformer.py#L778-L779

    https://github.com/lucidrains/sinkhorn-transformer/blob/2b65e24085562a4e308251398007e2ca9b86d7cc/sinkhorn_transformer/sinkhorn_transformer.py#L792-L793

    I guess pos_emb can be removed, as it introduces memory overhead and defeats the purpose of Axial Positional Encoding, which is designed to reduce the number of positional encoding parameters.

    Is there a special reason behind that?
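
The parameter saving mentioned in this issue can be illustrated with a little arithmetic. The sketch assumes the concatenated form of axial positional encoding, in which each axis gets its own table and the two embeddings are concatenated to reach the model dimension. `max_seq_len` and `dim` are taken from the README example; `axial_shape` and `axial_dims` are hypothetical choices, as are the function names.

```python
def full_pos_emb_params(max_seq_len, dim):
    """Parameters in a learned absolute table: one dim-sized vector per position."""
    return max_seq_len * dim


def axial_pos_emb_params(axial_shape, axial_dims):
    """Parameters in a two-axis axial table (concatenated variant, assumed here)."""
    rows, cols = axial_shape
    d_rows, d_cols = axial_dims
    return rows * d_rows + cols * d_cols


max_seq_len, dim = 8192, 1024  # values from the README example
axial_shape = (64, 128)        # hypothetical factorization, 64 * 128 == 8192
axial_dims = (512, 512)        # hypothetical split, 512 + 512 == 1024

full = full_pos_emb_params(max_seq_len, dim)            # 8,388,608
axial = axial_pos_emb_params(axial_shape, axial_dims)   # 98,304
```

Under these assumptions, keeping the full table alongside the axial one makes the full table account for almost all of the positional parameters.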

    opened by ilya16 4
  • A wrapper of SinkhornTransformerEncDec

    Could you please code up a wrapper that removes a lot of the manual work in writing a generic SinkhornTransformer encoder / decoder architecture?

    Thanks a lot! halexan

    opened by halexan 0
Owner
Phil Wang
Working with Attention. It's all we need