Reproduction of Vision Transformer in Tensorflow2. Train from scratch and Finetune.

Overview

Vision Transformer(ViT) in Tensorflow2

Tensorflow2 implementation of the Vision Transformer(ViT).

This repository is for An image is worth 16x16 words: Transformers for image recognition at scale and How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers.

Limitations.

  • Due to memory limitations, only the ti/16, s/16, and b/16 models were tested.
  • Due to memory limitations, batch_size 2048 in s16 and 1024 in b/16 (in paper, 4096).
  • Due to computational resource limitations, only reproduce using imagenet1k.

All experimental results and graphs are opend in Wandb.

Model weights

Since this is personal project, it is hard to train with large datasets like imagenet21k. For a pretrain model with good performance, see the official repo. But if you really need it, contact me.

Install dependencies

pip install -r requirements

All experiments were done on tpu_v3-8 with the support of TRC. But you can experiment on GPU. Check conf/config.yaml and conf/downstream.yaml

  # TPU options
  env:
    mode: tpu
    gcp_project: {your_project}
    tpu_name: node-1
    tpu_zone: europe-west4-a
    mixed_precision: True
  # GPU options
  # env:
  #   mode: gpu
  #   mixed_precision: True

Train from scratch

python run.py experiment=vit-s16-aug_light1-bs_2048-wd_0.1-do_0.1-dp_0.1-lr_1e-3 base.project_name=vit-s16-aug_light1-bs_2048-wd_0.1-do_0.1-dp_0.1-lr_1e-3 base.save_dir={your_save_dir} base.env.gcp_project={your_gcp_project} base.env.tpu_name={your_tpu_name} base.debug=False

Downstream

python run.py --config-name=downstream experiment=downstream-imagenet-ti16_384 base.pretrained={your_checkpoint} base.project_name={your_project_name} base.save_dir={your_save_dir} base.env.gcp_project={your_gcp_project} base.env.tpu_name={your_tpu_name} base.debug=False

Board

To track metics, you can use wandb or tensorboard (default: wandb). You can change in conf/callbacks/{filename.yaml}.

modules:
  - type: MonitorCallback
  - type: TerminateOnNaN
  - type: ProgbarLogger
    params:
      count_mode: steps
  - type: ModelCheckpoint
    params:
      filepath: ???
      save_weights_only: True
  - type: Wandb
    project: vit
    nested_dict: False
    hide_config: True
    params: 
      monitor: val_loss
      save_model: False
  # - type: TensorBoard
  #   params:
  #     log_dir: ???
  #     histogram_freq: 1

TFC

This open source was assisted by TPU Research Cloud (TRC) program

Thank you for providing the TPU.

Citations

@article{dosovitskiy2020image,
  title={An image is worth 16x16 words: Transformers for image recognition at scale},
  author={Dosovitskiy, Alexey and Beyer, Lucas and Kolesnikov, Alexander and Weissenborn, Dirk and Zhai, Xiaohua and Unterthiner, Thomas and Dehghani, Mostafa and Minderer, Matthias and Heigold, Georg and Gelly, Sylvain and others},
  journal={arXiv preprint arXiv:2010.11929},
  year={2020}
}
@article{steiner2021train,
  title={How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers},
  author={Steiner, Andreas and Kolesnikov, Alexander and Zhai, Xiaohua and Wightman, Ross and Uszkoreit, Jakob and Beyer, Lucas},
  journal={arXiv preprint arXiv:2106.10270},
  year={2021}
}
Owner
sungjun lee
AI Researcher
sungjun lee
Visualizer for neural network, deep learning, and machine learning models

Netron is a viewer for neural network, deep learning and machine learning models. Netron supports ONNX (.onnx, .pb, .pbtxt), Keras (.h5, .keras), Tens

Lutz Roeder 21k Jan 06, 2023
VLG-Net: Video-Language Graph Matching Networks for Video Grounding

VLG-Net: Video-Language Graph Matching Networks for Video Grounding Introduction Official repository for VLG-Net: Video-Language Graph Matching Networ

Mattia Soldan 25 Dec 04, 2022
In real-world applications of machine learning, reliable and safe systems must consider measures of performance beyond standard test set accuracy

PixMix Introduction In real-world applications of machine learning, reliable and safe systems must consider measures of performance beyond standard te

Andy Zou 79 Dec 30, 2022
NeuralForecast is a Python library for time series forecasting with deep learning models

NeuralForecast is a Python library for time series forecasting with deep learning models. It includes benchmark datasets, data-loading utilities, evaluation functions, statistical tests, univariate m

Nixtla 1.1k Jan 03, 2023
[NeurIPS 2021] Low-Rank Subspaces in GANs

Low-Rank Subspaces in GANs Figure: Image editing results using LowRankGAN on StyleGAN2 (first three columns) and BigGAN (last column). Low-Rank Subspa

112 Dec 28, 2022
Pre-trained Deep Learning models and demos (high quality and extremely fast)

OpenVINO™ Toolkit - Open Model Zoo repository This repository includes optimized deep learning models and a set of demos to expedite development of hi

OpenVINO Toolkit 3.4k Dec 31, 2022
some classic model used to segment the medical images like CT、X-ray and so on

github_project This is a project for medical image segmentation. This project includes common medical image segmentation models such as U-net, FCN, De

2 Mar 30, 2022
A lightweight face-recognition toolbox and pipeline based on tensorflow-lite

FaceIDLight 📘 Description A lightweight face-recognition toolbox and pipeline based on tensorflow-lite with MTCNN-Face-Detection and ArcFace-Face-Rec

Martin Knoche 16 Dec 07, 2022
This is the official source code for SLATE. We provide the code for the model, the training code, and a dataset loader for the 3D Shapes dataset. This code is implemented in Pytorch.

SLATE This is the official source code for SLATE. We provide the code for the model, the training code and a dataset loader for the 3D Shapes dataset.

Gautam Singh 66 Dec 26, 2022
This repository is a series of notebooks that show solutions for the projects at Dataquest.io.

Dataquest Project Solutions This repository is a series of notebooks that show solutions for the projects at Dataquest.io. Of course, there are always

Dataquest 1.1k Dec 30, 2022
OHLC Average Prediction of Apple Inc. Using LSTM Recurrent Neural Network

Stock Price Prediction of Apple Inc. Using Recurrent Neural Network OHLC Average Prediction of Apple Inc. Using LSTM Recurrent Neural Network Dataset:

Nouroz Rahman 410 Jan 05, 2023
Dynamic Graph Event Detection

DyGED Dynamic Graph Event Detection Get Started pip install -r requirements.txt TODO Paper link to arxiv, and how to cite. Twitter Weather dataset tra

Mert Koşan 3 May 09, 2022
pytorch implementation of openpose including Hand and Body Pose Estimation.

pytorch-openpose pytorch implementation of openpose including Body and Hand Pose Estimation, and the pytorch model is directly converted from openpose

Hzzone 1.4k Jan 07, 2023
InferPy: Deep Probabilistic Modeling with Tensorflow Made Easy

InferPy: Deep Probabilistic Modeling Made Easy InferPy is a high-level API for probabilistic modeling written in Python and capable of running on top

PGM-Lab 141 Oct 13, 2022
Self-Supervised Image Denoising via Iterative Data Refinement

Self-Supervised Image Denoising via Iterative Data Refinement Yi Zhang1, Dasong Li1, Ka Lung Law2, Xiaogang Wang1, Hongwei Qin2, Hongsheng Li1 1CUHK-S

Zhang Yi 72 Jan 01, 2023
Code for CVPR 2021 paper TransNAS-Bench-101: Improving Transferrability and Generalizability of Cross-Task Neural Architecture Search.

TransNAS-Bench-101 This repository contains the publishable code for CVPR 2021 paper TransNAS-Bench-101: Improving Transferrability and Generalizabili

Yawen Duan 17 Nov 20, 2022
Enabling dynamic analysis of Legacy Embedded Systems in full emulated environment

PENecro This project is based on "Enabling dynamic analysis of Legacy Embedded Systems in full emulated environment", published on hardwear.io USA 202

Ta-Lun Yen 10 May 17, 2022
sequitur is a library that lets you create and train an autoencoder for sequential data in just two lines of code

sequitur sequitur is a library that lets you create and train an autoencoder for sequential data in just two lines of code. It implements three differ

Jonathan Shobrook 305 Dec 21, 2022
1st Solution For NeurIPS 2021 Competition on ML4CO Dual Task

KIDA: Knowledge Inheritance in Data Aggregation This project releases our 1st place solution on NeurIPS2021 ML4CO Dual Task. Slide and model weights a

MEGVII Research 24 Sep 08, 2022
novel deep learning research works with PaddlePaddle

Research 发布基于飞桨的前沿研究工作,包括CV、NLP、KG、STDM等领域的顶会论文和比赛冠军模型。 目录 计算机视觉(Computer Vision) 自然语言处理(Natrual Language Processing) 知识图谱(Knowledge Graph) 时空数据挖掘(Spa

1.5k Dec 29, 2022