Implementation of Bottleneck Transformer in Pytorch

Overview

Implementation of Bottleneck Transformer, a SotA visual recognition model combining convolution and attention that outperforms EfficientNet and DeiT on the performance-compute trade-off, in Pytorch

Install

$ pip install bottleneck-transformer-pytorch

Usage

import torch
from torch import nn
from bottleneck_transformer_pytorch import BottleStack

layer = BottleStack(
    dim = 256,              # channels in
    fmap_size = 64,         # feature map size
    dim_out = 2048,         # channels out
    proj_factor = 4,        # projection factor
    downsample = True,      # downsample on first layer or not
    heads = 4,              # number of heads
    dim_head = 128,         # dimension per head, defaults to 128
    rel_pos_emb = False,    # use relative positional embedding - uses absolute if False
    activation = nn.ReLU()  # activation throughout the network
)

fmap = torch.randn(2, 256, 64, 64) # feature map from previous resnet block(s)

layer(fmap) # (2, 2048, 32, 32)

BotNet

With some simple model surgery on a ResNet, you can have the 'BotNet' (what a weird name) ready for training.

import torch
from torch import nn
from torchvision.models import resnet50

from bottleneck_transformer_pytorch import BottleStack

layer = BottleStack(
    dim = 256,
    fmap_size = 56,        # set specifically for imagenet's 224 x 224
    dim_out = 2048,
    proj_factor = 4,
    downsample = True,
    heads = 4,
    dim_head = 128,
    rel_pos_emb = True,
    activation = nn.ReLU()
)

resnet = resnet50()

# model surgery

backbone = list(resnet.children())

model = nn.Sequential(
    *backbone[:5],                # conv stem, bn, relu, maxpool, and layer1 (ends at 56 x 56 feature maps)
    layer,                        # BottleStack in place of the remaining resnet stages
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten(1),
    nn.Linear(2048, 1000)
)

# use the 'BotNet'

img = torch.randn(2, 3, 224, 224)
preds = model(img) # (2, 1000)

Citations

@misc{srinivas2021bottleneck,
    title   = {Bottleneck Transformers for Visual Recognition}, 
    author  = {Aravind Srinivas and Tsung-Yi Lin and Niki Parmar and Jonathon Shlens and Pieter Abbeel and Ashish Vaswani},
    year    = {2021},
    eprint  = {2101.11605},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
Comments
  • How should I modify the code if the input feature map has unequal width and height?

    Assume the width and height of the feature map are 10 and 8, respectively. Could you please check whether my modification of the RelPosEmb class below is correct?

    class RelPosEmb(nn.Module):
        def __init__(
            self,
            fmap_size,
            dim_head
        ):
            super().__init__()
            scale = dim_head ** -0.5
            self.fmap_size = fmap_size
            self.scale = scale
            # self.rel_height = nn.Parameter(torch.randn(fmap_size * 2 - 1, dim_head) * scale)
            self.rel_height = nn.Parameter(torch.randn(8 * 2 - 1, dim_head) * scale)
            # self.rel_width = nn.Parameter(torch.randn(fmap_size * 2 - 1, dim_head) * scale)
            self.rel_width = nn.Parameter(torch.randn(10 * 2 - 1, dim_head) * scale)

        def forward(self, q):
            q = rearrange(q, 'b h (x y) d -> b h x y d', x = 8)
            rel_logits_w = relative_logits_1d(q, self.rel_width)
            rel_logits_w = rearrange(rel_logits_w, 'b h x i y j -> b h (x y) (i j)')

            q = rearrange(q, 'b h x y d -> b h y x d')
            rel_logits_h = relative_logits_1d(q, self.rel_height)
            rel_logits_h = rearrange(rel_logits_h, 'b h x i y j -> b h (y x) (j i)')
            return rel_logits_w + rel_logits_h
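
    For reference, here is a minimal sketch of RelPosEmb generalized to rectangular feature maps, assuming fmap_size is passed as a (height, width) tuple and reusing rearrange and relative_logits_1d from the library. With (8, 10) it reduces to the hard-coded modification above.

    class RelPosEmb(nn.Module):
        def __init__(
            self,
            fmap_size,   # (height, width) tuple, e.g. (8, 10)
            dim_head
        ):
            super().__init__()
            height, width = fmap_size
            scale = dim_head ** -0.5
            self.fmap_size = fmap_size
            self.scale = scale
            # one embedding per possible offset along each axis
            self.rel_height = nn.Parameter(torch.randn(height * 2 - 1, dim_head) * scale)
            self.rel_width = nn.Parameter(torch.randn(width * 2 - 1, dim_head) * scale)

        def forward(self, q):
            height, width = self.fmap_size
            q = rearrange(q, 'b h (x y) d -> b h x y d', x = height, y = width)
            rel_logits_w = relative_logits_1d(q, self.rel_width)
            rel_logits_w = rearrange(rel_logits_w, 'b h x i y j -> b h (x y) (i j)')

            q = rearrange(q, 'b h x y d -> b h y x d')
            rel_logits_h = relative_logits_1d(q, self.rel_height)
            rel_logits_h = rearrange(rel_logits_h, 'b h x i y j -> b h (y x) (j i)')
            return rel_logits_w + rel_logits_h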
    
    opened by ShuweiShao 4
  • Feature map size

    Hi, in my case the input image sizes are all different, so the feature map size keeps changing. How should the fmap_size parameter of BottleStack be set in this case? Is it possible to train with an unfixed feature map size?
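
    The relative and absolute positional embeddings are nn.Parameters sized by fmap_size, so BottleStack as published expects a constant feature map size. One workaround, as a sketch rather than a supported feature, is to pool every incoming feature map to a fixed spatial size before the stack:

    import torch
    from torch import nn
    from bottleneck_transformer_pytorch import BottleStack

    model = nn.Sequential(
        nn.AdaptiveAvgPool2d((64, 64)),  # force the fixed size that fmap_size expects
        BottleStack(
            dim = 256,
            fmap_size = 64,
            dim_out = 2048,
            proj_factor = 4,
            downsample = True,
            heads = 4,
            dim_head = 128,
            rel_pos_emb = False,
            activation = nn.ReLU()
        )
    )

    fmap = torch.randn(2, 256, 100, 80)  # any spatial size
    out = model(fmap)                    # (2, 2048, 32, 32)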

    opened by benlee73 3
  • A little bug.

    https://github.com/lucidrains/bottleneck-transformer-pytorch/blob/b789de6db39f33854862fbc9bcee27c697cf003c/bottleneck_transformer_pytorch/bottleneck_transformer_pytorch.py#L16

    It is necessary to specify the device here:

    flat_pad = torch.zeros((b, h, l - 1), device = device, dtype = dtype) 
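
    For context, a tensor created without an explicit device defaults to the CPU and then fails when combined with CUDA tensors; a minimal reproduction sketch (assuming a CUDA machine):

    import torch

    q = torch.randn(2, 4, 8, device = 'cuda')
    pad = torch.zeros(2, 4, 7)                      # defaults to CPU
    # torch.cat((q, pad), dim = -1)                 # RuntimeError: tensors on different devices
    pad = torch.zeros(2, 4, 7, device = q.device)   # matches the input's device
    out = torch.cat((q, pad), dim = -1)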
    
    opened by lartpang 1
  • fix inplace operations

    Recent versions of PyTorch throw runtime errors for in-place operations like *= and += on tensors that require gradients. This pull request fixes the issue by replacing them with their out-of-place equivalents.
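
    For illustration, the kind of change involved (a hypothetical minimal example, not code from the pull request):

    import torch

    x = torch.randn(4, requires_grad = True)

    # x *= 2      # in-place op on a tensor requiring grad -> RuntimeError
    x = x * 2     # out-of-place equivalent that autograd can track
    x.sum().backward()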

    opened by AminRezaei0x443 0
  • Could you explain the implementation of relative position embedding?

    Reference: https://github.com/tensorflow/tensor2tensor/blob/5f9dd2db6d7797162e53adf152310ed13e9fc711/tensor2tensor/layers/common_attention.py

    def _generate_relative_positions_matrix(length_q, length_k,
                                            max_relative_position,
                                            cache=False):
      """Generates matrix of relative positions between inputs."""
      if not cache:
        if length_q == length_k:
          range_vec_q = range_vec_k = tf.range(length_q)
        else:
          range_vec_k = tf.range(length_k)
          range_vec_q = range_vec_k[-length_q:]
        distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
      else:
        distance_mat = tf.expand_dims(tf.range(-length_k+1, 1, 1), 0)
      distance_mat_clipped = tf.clip_by_value(distance_mat, -max_relative_position,
                                              max_relative_position)
      # Shift values to be >= 0. Each integer still uniquely identifies a relative
      # position difference.
      final_mat = distance_mat_clipped + max_relative_position
      return final_mat
    
    
    def _generate_relative_positions_embeddings(length_q, length_k, depth,
                                                max_relative_position, name,
                                                cache=False):
      """Generates tensor of size [1 if cache else length_q, length_k, depth]."""
      with tf.variable_scope(name):
        relative_positions_matrix = _generate_relative_positions_matrix(
            length_q, length_k, max_relative_position, cache=cache)
        vocab_size = max_relative_position * 2 + 1
        # Generates embedding for each relative position of dimension depth.
        embeddings_table = tf.get_variable("embeddings", [vocab_size, depth])
        embeddings = tf.gather(embeddings_table, relative_positions_matrix)
        return embeddings
    
    
    def _relative_attention_inner(x, y, z, transpose):
      """Relative position-aware dot-product attention inner calculation.
      This batches matrix multiply calculations to avoid unnecessary broadcasting.
      Args:
        x: Tensor with shape [batch_size, heads, length or 1, length or depth].
        y: Tensor with shape [batch_size, heads, length or 1, depth].
        z: Tensor with shape [length or 1, length, depth].
        transpose: Whether to transpose inner matrices of y and z. Should be true if
            last dimension of x is depth, not length.
      Returns:
        A Tensor with shape [batch_size, heads, length, length or depth].
      """
      batch_size = tf.shape(x)[0]
      heads = x.get_shape().as_list()[1]
      length = tf.shape(x)[2]
    
      # xy_matmul is [batch_size, heads, length or 1, length or depth]
      xy_matmul = tf.matmul(x, y, transpose_b=transpose)
      # x_t is [length or 1, batch_size, heads, length or depth]
      x_t = tf.transpose(x, [2, 0, 1, 3])
      # x_t_r is [length or 1, batch_size * heads, length or depth]
      x_t_r = tf.reshape(x_t, [length, heads * batch_size, -1])
      # x_tz_matmul is [length or 1, batch_size * heads, length or depth]
      x_tz_matmul = tf.matmul(x_t_r, z, transpose_b=transpose)
      # x_tz_matmul_r is [length or 1, batch_size, heads, length or depth]
      x_tz_matmul_r = tf.reshape(x_tz_matmul, [length, batch_size, heads, -1])
      # x_tz_matmul_r_t is [batch_size, heads, length or 1, length or depth]
      x_tz_matmul_r_t = tf.transpose(x_tz_matmul_r, [1, 2, 0, 3])
      return xy_matmul + x_tz_matmul_r_t
    
    
    def dot_product_attention_relative(q,
                                       k,
                                       v,
                                       bias,
                                       max_relative_position,
                                       dropout_rate=0.0,
                                       image_shapes=None,
                                       save_weights_to=None,
                                       name=None,
                                       make_image_summary=True,
                                       cache=False,
                                       allow_memory=False,
                                       hard_attention_k=0,
                                       gumbel_noise_weight=0.0):
      """Calculate relative position-aware dot-product self-attention.
      The attention calculation is augmented with learned representations for the
      relative position between each element in q and each element in k and v.
      Args:
        q: a Tensor with shape [batch, heads, length, depth].
        k: a Tensor with shape [batch, heads, length, depth].
        v: a Tensor with shape [batch, heads, length, depth].
        bias: bias Tensor.
        max_relative_position: an integer specifying the maximum distance between
            inputs that unique position embeddings should be learned for.
        dropout_rate: a floating point number.
        image_shapes: optional tuple of integer scalars.
        save_weights_to: an optional dictionary to capture attention weights
          for visualization; the weights tensor will be appended there under
          a string key created from the variable scope (including name).
        name: an optional string.
        make_image_summary: Whether to make an attention image summary.
        cache: whether use cache mode
        allow_memory: whether to assume that recurrent memory is in use. If True,
          the length dimension of k/v/bias may be longer than the queries, and it is
          assumed that the extra memory entries precede the non-memory entries.
        hard_attention_k: integer, if > 0 triggers hard attention (picking top-k)
        gumbel_noise_weight: if > 0, apply Gumbel noise with weight
          `gumbel_noise_weight` before picking top-k. This is a no op if
          hard_attention_k <= 0.
      Returns:
        A Tensor.
      Raises:
        ValueError: if max_relative_position is not > 0.
      """
      if not max_relative_position:
        raise ValueError("Max relative position (%s) should be > 0 when using "
                         "relative self attention." % (max_relative_position))
      with tf.variable_scope(
          name, default_name="dot_product_attention_relative",
          values=[q, k, v]) as scope:
    
        # This calculation only works for self attention.
        # q, k and v must therefore have the same shape, unless memory is enabled.
        if not cache and not allow_memory:
          q.get_shape().assert_is_compatible_with(k.get_shape())
          q.get_shape().assert_is_compatible_with(v.get_shape())
    
        # Use separate embeddings suitable for keys and values.
        depth = k.get_shape().as_list()[3]
        length_k = common_layers.shape_list(k)[2]
        length_q = common_layers.shape_list(q)[2] if allow_memory else length_k
        relations_keys = _generate_relative_positions_embeddings(
            length_q, length_k, depth, max_relative_position,
            "relative_positions_keys", cache=cache)
        relations_values = _generate_relative_positions_embeddings(
            length_q, length_k, depth, max_relative_position,
            "relative_positions_values", cache=cache)
    
        # Compute self attention considering the relative position embeddings.
        logits = _relative_attention_inner(q, k, relations_keys, True)
        if bias is not None:
          logits += bias
        weights = tf.nn.softmax(logits, name="attention_weights")
        if hard_attention_k > 0:
          weights = harden_attention_weights(weights, hard_attention_k,
                                             gumbel_noise_weight)
        if save_weights_to is not None:
          save_weights_to[scope.name] = weights
          save_weights_to[scope.name + "/logits"] = logits
        weights = tf.nn.dropout(weights, 1.0 - dropout_rate)
        if (not tf.get_variable_scope().reuse and
            common_layers.should_generate_summaries() and
            make_image_summary):
          attention_image_summary(weights, image_shapes)
        return _relative_attention_inner(weights, v, relations_values, False)
    

    which corresponds to the formula clip(x, k) = max(-k, min(k, x))

    But in your code there is just a randn with grad, which I don't understand. Could you give an explanation?
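
    For reference, the nn.Parameter initialized with torch.randn plays the same role as tensor2tensor's tf.get_variable embeddings table: it is a learned lookup table indexed by relative distance, and since it covers all 2 * size - 1 possible offsets, no clipping is needed. A minimal sketch of the equivalence (names hypothetical):

    import torch
    from torch import nn

    size, depth = 8, 16
    rel = torch.arange(size)[None, :] - torch.arange(size)[:, None]  # offsets in [-(size - 1), size - 1]

    # tensor2tensor style: an embedding table gathered by shifted relative distance
    table_tf_style = nn.Embedding(2 * size - 1, depth)
    emb_a = table_tf_style(rel + size - 1)                           # (size, size, depth)

    # this repo's style: the randomly initialized Parameter is the same learned table
    table = nn.Parameter(torch.randn(2 * size - 1, depth))
    emb_b = table[rel + size - 1]                                    # (size, size, depth)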

    opened by AncientRemember 0
  • Is it possible to modify these codes to support 3D images as well?

    Thank you for your great work!

    I was wondering if it is possible to modify these codes to support 3D images as well (i.e. adding a z-axis).

    I can't imagine how to change the dimensions of the vectors in the "content-position" part. E.g. Hx1xd and 1xWxd -> Hx1x1xd, 1x1xZxd and 1xWx1xd?

    Thank you for your answer!
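
    For what it's worth, the broadcast proposed above does type-check; a shape-only sketch (this ignores the relative-shift indexing the 2D implementation uses, so it is a plausibility check rather than a drop-in extension):

    import torch

    H, W, Z, d = 4, 5, 6, 8
    q = torch.randn(H * W * Z, d)                  # flattened queries

    emb_h = torch.randn(H, 1, 1, d)                # broadcast over W and Z
    emb_w = torch.randn(1, W, 1, d)                # broadcast over H and Z
    emb_z = torch.randn(1, 1, Z, d)                # broadcast over H and W

    pos = (emb_h + emb_w + emb_z).reshape(-1, d)   # (H * W * Z, d) position encodings
    logits = q @ pos.t()                           # (H * W * Z, H * W * Z) content-position term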

    opened by kyuchoi 0
  • Hello, I get the following error during training; how can I solve it?

    einops.EinopsError: Error while processing rearrange-reduction pattern "b h (x y) d -> b h x y d". Input tensor shape: torch.Size([2, 4, 900, 128]). Additional info: {'x': 26, 'y': 26}. Shape mismatch, 900 != 676
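
    For reference, 900 = 30 x 30 while the layer was built for 26 x 26 = 676 positions, so the fmap_size passed to BottleStack (26) presumably does not match the actual spatial size of the incoming feature map; setting fmap_size to the true size (here 30) should resolve the mismatch.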

    opened by glt999 1
  • the size of tensor a (9) must match the size of tensor b (10) at non-singleton dimension 3

    Hello, I want to ask a question. The input feature map is 228 x 304, but I get this error: the size of tensor a (9) must match the size of tensor b (10) at non-singleton dimension 3.

    opened by shezhi 1
  • The 2D relative position embedding is not inductive; maybe the FLOATER embedding would be better

    opened by AncientRemember 2