Aggregating Nested Transformers: Official JAX Implementation

Overview

NesT is a simple method that aggregates nested local transformers on image blocks. This design gives vision transformers better accuracy, data efficiency, and convergence on the ImageNet benchmark. NesT can also be scaled down to small datasets, where it matches convnet accuracy.
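
To make the "image blocks" idea concrete, here is a minimal sketch of the blocking step only: a feature map is split into non-overlapping blocks so that a local transformer could attend within each block independently. It is written with NumPy for illustration (jax.numpy offers the same reshape and transpose calls). The function names are hypothetical and are not taken from this repository, and both the attention layers and the aggregation between hierarchy levels are omitted.

```python
import numpy as np


def partition_into_blocks(x, block_size):
    """Split a (H, W, C) feature map into non-overlapping square blocks.

    Returns an array of shape (num_blocks, block_size * block_size, C), so
    that attention could be applied independently inside each block.
    """
    height, width, channels = x.shape
    if height % block_size or width % block_size:
        raise ValueError("height and width must be divisible by block_size")
    grid_h, grid_w = height // block_size, width // block_size
    # (grid_h, block, grid_w, block, C) -> (grid_h, grid_w, block, block, C)
    x = x.reshape(grid_h, block_size, grid_w, block_size, channels)
    x = x.transpose(0, 2, 1, 3, 4)
    return x.reshape(grid_h * grid_w, block_size * block_size, channels)


def merge_blocks(blocks, height, width):
    """Inverse of partition_into_blocks: rebuild the (H, W, C) feature map."""
    _, tokens, channels = blocks.shape
    block_size = int(round(tokens ** 0.5))
    grid_h, grid_w = height // block_size, width // block_size
    x = blocks.reshape(grid_h, grid_w, block_size, block_size, channels)
    x = x.transpose(0, 2, 1, 3, 4)
    return x.reshape(height, width, channels)


# Toy 8x8 feature map with 3 channels, split into four 4x4 blocks.
feature_map = np.arange(8 * 8 * 3, dtype=np.float32).reshape(8, 8, 3)
blocks = partition_into_blocks(feature_map, block_size=4)
restored = merge_blocks(blocks, height=8, width=8)
print(blocks.shape)  # (4, 16, 3)
```

Each row of `blocks` holds the 16 tokens of one 4x4 block, and `merge_blocks` restores the original spatial layout.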

This is not an officially supported Google product.

Pretrained Models and Results

Model    Accuracy (%)   Checkpoint path
NesT-B   83.8           gs://gresearch/nest-checkpoints/nest-b_imagenet
NesT-S   83.3           gs://gresearch/nest-checkpoints/nest-s_imagenet
NesT-T   81.5           gs://gresearch/nest-checkpoints/nest-t_imagenet

Note: Accuracy is evaluated on the ImageNet2012 validation set.

TensorBoard.dev

See the ImageNet training logs at TensorBoard.dev.

Colab

A Colab notebook is available for testing: https://colab.sandbox.google.com/github/google-research/nested-transformer/blob/main/colab.ipynb

Instructions for Image Classification

Environment setup

virtualenv -p python3 --system-site-packages nestenv
source nestenv/bin/activate

pip install -r requirements.txt

Evaluate on ImageNet

Before the first run, download ImageNet by following the tensorflow_datasets instructions for the command line. Optionally, download all pre-trained checkpoints:

bash ./checkpoints/download_checkpoints.sh

Run the evaluation script to evaluate NesT-B.

python main.py --config configs/imagenet_nest.py --config.eval_only=True \
  --config.init_checkpoint="./checkpoints/nest-b_imagenet/ckpt.39" \
  --workdir="./checkpoints/nest-b_imagenet_eval"
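
If the evaluation is launched from a script rather than typed by hand, the command above can be assembled as an argument list. The sketch below uses only the flags shown in this README. The helper name `build_eval_command` is hypothetical, and the workdir value is simply an output directory chosen for this example.

```python
import shlex


def build_eval_command(config_file, init_checkpoint, workdir):
    """Return the evaluation command from this README as an argument list."""
    return [
        "python", "main.py",
        "--config", config_file,
        "--config.eval_only=True",
        f"--config.init_checkpoint={init_checkpoint}",
        f"--workdir={workdir}",
    ]


command = build_eval_command(
    config_file="configs/imagenet_nest.py",
    init_checkpoint="./checkpoints/nest-b_imagenet/ckpt.39",
    workdir="./checkpoints/nest-b_imagenet_eval",  # example output directory
)
# Print the command as a single shell-quoted line for inspection.
print(shlex.join(command))
# To launch it from Python: subprocess.run(command, check=True)
```

Passing the list form to `subprocess.run` avoids shell-quoting problems with the checkpoint path.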

Train on ImageNet

The default configuration trains NesT-B on a TPUv2 8x8 with a per-device batch size of 16.

python main.py --config configs/imagenet_nest.py --jax_backend_target=<TPU_IP_ADDRESS> --jax_xla_backend="tpu_driver" --workdir="./checkpoints/nest-b_imagenet"

Note: See jax/cloud_tpu_colab for info about TPU_IP_ADDRESS.

Train NesT-T on 8 GPUs.

python main.py --config configs/imagenet_nest_tiny.py --workdir="./checkpoints/nest-t_imagenet_8gpu"

The codebase does not support multi-node GPU training (more than 8 GPUs). The models reported in our paper were trained on TPUs with a total batch size of 1024.
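
The total and per-device batch sizes quoted above are related by simple arithmetic: the total is the per-device batch size multiplied by the number of devices. The sketch below shows the per-device value needed to reach the paper's total of 1024 on a given device count. The helper name is hypothetical. It does not address whether such a batch fits in GPU memory, or whether other hyperparameters should change with the batch size.

```python
def per_device_batch_size(total_batch_size, num_devices):
    """Per-device batch size needed to reach a given total batch size."""
    if num_devices <= 0:
        raise ValueError("num_devices must be positive")
    if total_batch_size % num_devices:
        raise ValueError("total batch size must be divisible by device count")
    return total_batch_size // num_devices


# Reaching a total batch size of 1024 on 8 devices.
print(per_device_batch_size(1024, 8))  # 128

# Conversely, a per-device batch size of 16 reaches 1024 with 64 devices.
devices_needed = 1024 // 16
print(devices_needed)  # 64
```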

Train on CIFAR

# Training on 2 GPUs is recommended; NesT-T can be trained on 1 GPU.
CUDA_VISIBLE_DEVICES=0,1 python main.py --config configs/cifar_nest.py --workdir="./checkpoints/nest_cifar"

Cite

@inproceedings{zhang2021aggregating,
  title={Aggregating Nested Transformers},
  author={Zizhao Zhang and Han Zhang and Long Zhao and Ting Chen and Tomas Pfister},
  booktitle={arXiv preprint arXiv:2105.12723},
  year={2021}
}

Owner

Google Research