Paddle pit - Rethinking Spatial Dimensions of Vision Transformers

Overview

基于Paddle实现PiT ——Rethinking Spatial Dimensions of Vision Transformers,arxiv

  • 官方原版代码(基于PyTorch)pit.

  • 本项目基于 PaddleViT 实现,在其基础上与原版代码实现了更进一步的对齐,并通过完整训练与测试完成对pit_ti模型的复现.

1. 简介

从CNN的成功设计原理出发,作者研究了空间尺寸转换的作用及其在基于Transformer的体系结构上的有效性。

具体来说,类似于CNN的降维原则(随着深度的增加,传统的CNN会增加通道尺寸并减小空间尺寸),作者用实验表明了这同样有利于Transformer的性能提升,并提出了基于池化的Vision Transformer,即PiT(模型示意图如下)。

drawing

PiT 模型示意图

2. 数据集和复现精度

数据集

原文使用的为ImageNet-1k 2012(ILSVRC2012),共1000类,训练集/测试集图片分布:1281167/50000,数据集大小为144GB。

本项目使用的为官方推荐的图片压缩过的更轻量的Light_ILSVRC2012,数据集大小为65GB。其在AI Studio上的地址为:Light_ILSVRC2012_part_0.tarLight_ILSVRC2012_part_1.tar

复现精度

Model 目标精度[email protected] 实现精度[email protected] Image Size batch_size Crop_pct epoch #Params
pit_ti 73.0 73.01 224 256*4GPUs 0.9 300
(+10 COOLDOWN)
4.8M

【注】上表中的实现精度在原版ILSVRC2012验证集上测试得到。 值得一提的是,本项目在Light_ILSVRC2012的验证集上的Validation [email protected]达到了73.17

本项目训练得到的最佳模型参数与训练日志log均存放于output文件夹下。

日志文件说明

本项目通过AI Studio的脚本任务运行,中途中断了4次,因此共有5个日志文件。为了方便检阅,本人手动将log命名为log_开始epoch-结束epoch.txt格式。具体来说:

  • output/log_1-76.txt:epoch1~epoch76。这一版代码定义每10个epoch保存一次模型权重,每2个epoch验证一次,同时若验证精度高于历史精度,则保存为Best_PiT.pdparams,因此在epoch76训练结束但还未验证的时候中断,下一次的训练只能从验证精度最高的epoch74继续训练。

  • output/log_75-142.txt:epoch75~epoch142。从这一版代码开始,新增了每次训练之后都保存一下模型参数为PiT-Latest.pdparams,这样无论哪个epoch训练中断都可以继续训练啦。

  • output/log_143-225.txt:epoch143~epoch225。

  • output/log_226-303.txt:epoch226~epoch303。

  • output/log_304-310.txt:epoch304~epoch310。

  • output/log_eval.txt:使用训练得到的最好模型(epoch308)在原版ILSVRC2012验证集上验证日志。

3. 准备环境

推荐环境配置:

本人环境配置:

  • 硬件:Tesla V100 * 4(由衷感谢百度飞桨平台提供高性能算力支持)

  • PaddlePaddle==2.2.1

  • Python==3.7

4. 快速开始

本项目现已通过脚本任务形式部署到AI Studio上,您可以选择fork下来直接运行sh run.sh,数据集处理等脚本均已部署好。链接:paddle_pit

或者您也可以git本repo在本地运行:

第一步:克隆本项目

git clone https://github.com/hatimwen/paddle_pit.git
cd paddle_pit

第二步:修改参数

请根据实际情况,修改scripts路径下的脚本内容(如:gpu,数据集路径data_path,batch_size等)。

第三步:验证模型

多卡请运行:

sh scripts/run_eval_multi.sh

单卡请运行:

sh scripts/run_eval.sh

第四步:训练模型

多卡请运行:

sh scripts/run_train_multi.sh

单卡请运行:

sh scripts/run_train.sh

第五步:验证预测

python predict.py \
-pretrained='output/Best_PiT' \
-img_path='images/ILSVRC2012_val_00004506.JPEG'

验证图片(类别:藏獒, id: 244)

输出结果为:

class_id: 244, prob: 9.12291145324707

对照ImageNet类别id(ImageNet数据集编号对应的类别内容),可知244为藏獒,预测结果正确。

5.代码结构

|-- paddle_pit
    |-- output              # 日志及模型文件
    |-- configs             # 参数
        |-- pit_ti.yaml
    |-- datasets
        |-- ImageNet1K      # 数据集路径
    |-- scripts             # 运行脚本
        |-- run_train.sh
        |-- run_train_multi.sh
        |-- run_eval.sh
        |-- run_eval_multi.sh
    |-- augment.py          # 数据增强
    |-- config.py           # 最底层配置文件
    |-- datasets.py         # dataset与dataloader
    |-- droppath.py         # droppath定义
    |-- losses.py           # loss定义
    |-- main_multi_gpu.py   # 多卡训练测试代码
    |-- main_single_gpu.py  # 单卡训练测试代码
    |-- mixup.py            # mixup定义
    |-- model_ema.py        # EMA定义
    |-- pit.py              # pit模型结构定义
    |-- random_erasing.py   # random_erasing定义
    |-- regnet.py           # 教师模型定义(本项目并未对此验证,仅作保留)
    |-- transforms.py       # RandomHorizontalFlip定义
    |-- utils.py            # CosineLRScheduler及AverageMeter定义
    |-- README.md
    |-- requirements.txt

6. 参考及引用

@InProceedings{Yuan_2021_ICCV,
    author    = {Yuan, Li and Chen, Yunpeng and Wang, Tao and Yu, Weihao and Shi, Yujun and Jiang, Zi-Hang and Tay, Francis E.H. and Feng, Jiashi and Yan, Shuicheng},
    title     = {Tokens-to-Token ViT: Training Vision Transformers From Scratch on ImageNet},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month     = {October},
    year      = {2021},
    pages     = {558-567}
}

最后,非常感谢百度举办的飞桨论文复现挑战赛(第五期)让本人对Paddle理解更加深刻。 同时也非常感谢朱欤老师团队用Paddle实现的PaddleViT,本项目大部分代码都是从中copy来的,而仅仅实现了其与原版代码训练步骤的进一步对齐与完整的训练过程,但本人也同样受益匪浅! ♥️

Contact

Owner
Hongtao Wen
Hongtao Wen
Enabling dynamic analysis of Legacy Embedded Systems in full emulated environment

PENecro This project is based on "Enabling dynamic analysis of Legacy Embedded Systems in full emulated environment", published on hardwear.io USA 202

Ta-Lun Yen 10 May 17, 2022
Auxiliary Raw Net (ARawNet) is a ASVSpoof detection model taking both raw waveform and handcrafted features as inputs, to balance the trade-off between performance and model complexity.

Overview This repository is an implementation of the Auxiliary Raw Net (ARawNet), which is ASVSpoof detection system taking both raw waveform and hand

6 Jul 08, 2022
PyTorch implementation of UNet++ (Nested U-Net).

PyTorch implementation of UNet++ (Nested U-Net) This repository contains code for a image segmentation model based on UNet++: A Nested U-Net Architect

4ui_iurz1 642 Jan 04, 2023
CVPR 2021

Smoothing the Disentangled Latent Style Space for Unsupervised Image-to-image Translation [Paper] | [Poster] | [Codes] Yahui Liu1,3, Enver Sangineto1,

Yahui Liu 37 Sep 12, 2022
Re-implementation of the Noise Contrastive Estimation algorithm for pyTorch, following "Noise-contrastive estimation: A new estimation principle for unnormalized statistical models." (Gutmann and Hyvarinen, AISTATS 2010)

Noise Contrastive Estimation for pyTorch Overview This repository contains a re-implementation of the Noise Contrastive Estimation algorithm, implemen

Denis Emelin 42 Nov 24, 2022
This repository collects project-relevant Isabelle/HOL formalizations.

Isabelle/HOL formalizations related to the AuReLeE project Formalization of Abstract Argumentation Frameworks See AbstractArgumentation folder for the

AuReLeE project 1 Sep 10, 2022
Retinal vessel segmentation based on GT-UNet

Retinal vessel segmentation based on GT-UNet Introduction This project is a retinal blood vessel segmentation code based on UNet-like Group Transforme

Kent0n 27 Dec 18, 2022
Contrastive Learning Inverts the Data Generating Process

Official code to reproduce the results and data presented in the paper Contrastive Learning Inverts the Data Generating Process.

71 Nov 25, 2022
We present a regularized self-labeling approach to improve the generalization and robustness properties of fine-tuning.

Overview This repository provides the implementation for the paper "Improved Regularization and Robustness for Fine-tuning in Neural Networks", which

NEU-StatsML-Research 21 Sep 08, 2022
Mixed Transformer UNet for Medical Image Segmentation

MT-UNet Update 2021/11/19 Thank you for your interest in our work. We have uploaded the code of our MTUNet to help peers conduct further research on i

dotman 92 Dec 25, 2022
Hierarchical Attentive Recurrent Tracking

Hierarchical Attentive Recurrent Tracking This is an official Tensorflow implementation of single object tracking in videos by using hierarchical atte

Adam Kosiorek 147 Aug 07, 2021
Dynamic Neural Representational Decoders for High-Resolution Semantic Segmentation

Dynamic Neural Representational Decoders for High-Resolution Semantic Segmentation Requirements This repository needs mmsegmentation Training To train

20 May 28, 2022
Differentiable Simulation of Soft Multi-body Systems

Differentiable Simulation of Soft Multi-body Systems Yi-Ling Qiao, Junbang Liang, Vladlen Koltun, Ming C. Lin [Paper] [Code] Updates The C++ backend s

YilingQiao 26 Dec 23, 2022
Contra is a lightweight, production ready Tensorflow alternative for solving time series prediction challenges with AI

Contra AI Engine A lightweight, production ready Tensorflow alternative developed by Styvio styvio.com » How to Use · Report Bug · Request Feature Tab

styvio 14 May 25, 2022
Repository for "Exploring Sparsity in Image Super-Resolution for Efficient Inference", CVPR 2021

SMSR Reposity for "Exploring Sparsity in Image Super-Resolution for Efficient Inference" [arXiv] Highlights Locate and skip redundant computation in S

Longguang Wang 225 Dec 26, 2022
AEI: Actors-Environment Interaction with Adaptive Attention for Temporal Action Proposals Generation

AEI: Actors-Environment Interaction with Adaptive Attention for Temporal Action Proposals Generation A pytorch-version implementation codes of paper:

11 Dec 13, 2022
This repository contains project created during the Data Challenge module at London School of Hygiene & Tropical Medicine

LSHTM_RCS This repository contains project created during the Data Challenge module at London School of Hygiene & Tropical Medicine (LSHTM) in collabo

Lukas Kopecky 3 Jan 30, 2022
PyTorch code for the paper "FIERY: Future Instance Segmentation in Bird's-Eye view from Surround Monocular Cameras"

FIERY This is the PyTorch implementation for inference and training of the future prediction bird's-eye view network as described in: FIERY: Future In

Wayve 406 Dec 24, 2022
This is an example of object detection on Micro bacterium tuberculosis using Mask-RCNN

Mask-RCNN on Mycobacterium tuberculosis This is an example of object detection on Mycobacterium Tuberculosis using Mask RCNN. Implement of Mask R-CNN

Jun-En Ding 1 Sep 16, 2021
This repository contains pre-trained models and some evaluation code for our paper Towards Unsupervised Dense Information Retrieval with Contrastive Learning

Contriever: Towards Unsupervised Dense Information Retrieval with Contrastive Learning This repository contains pre-trained models and some evaluation

Meta Research 207 Jan 08, 2023