Paddle pit - Rethinking Spatial Dimensions of Vision Transformers

Overview

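A Paddle implementation of PiT (Rethinking Spatial Dimensions of Vision Transformers, arXiv).

  • Official code (PyTorch): pit.

  • This project builds on PaddleViT, aligns it more closely with the official code, and reproduces the pit_ti model through complete training and testing.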

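1. Introduction

Starting from the design principles behind successful CNNs, the authors study the role of spatial dimension conversion and its effectiveness in Transformer-based architectures.

Specifically, CNNs follow a dimension-reduction principle: as depth increases, a conventional CNN increases the channel dimension and decreases the spatial dimensions. The authors show experimentally that the same principle also improves Transformer performance, and they propose the Pooling-based Vision Transformer, PiT (see the model diagram below).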

Figure: PiT model architecture
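
The trade-off between spatial size and channel width can be made concrete with a little shape arithmetic. The sketch below is plain Python (no Paddle needed). It assumes the pit_ti hyperparameters of the official PyTorch code (patch size 16, stride 8, base_dims [32, 32, 32], heads [2, 4, 8]) and a stride-2 pooling convolution with kernel 3 and padding 1. None of these values is stated in this README, so treat them as assumptions and check them against pit.py.

```python
def conv_out(size, kernel, stride, padding=0):
    """Output length of a convolution along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


def pit_stage_shapes(image_size=224, patch_size=16, stride=8,
                     base_dims=(32, 32, 32), heads=(2, 4, 8), pool_stride=2):
    """Return (token-grid side length, channels) for every stage.

    All default values are assumptions about pit_ti, not read from this repo.
    """
    size = conv_out(image_size, patch_size, stride)  # patch embedding
    shapes = []
    for i, (dim, head) in enumerate(zip(base_dims, heads)):
        if i > 0:
            # Pooling between stages: the spatial size roughly halves.
            size = conv_out(size, pool_stride + 1, pool_stride, pool_stride // 2)
        shapes.append((size, dim * head))  # channels = dim per head * heads
    return shapes


for stage, (size, channels) in enumerate(pit_stage_shapes(), start=1):
    print(f"stage {stage}: {size}x{size} tokens, {channels} channels")
```

Under these assumptions the token grid shrinks from 27x27 to 14x14 to 7x7 while the channel width grows from 64 to 128 to 256, which mirrors the CNN principle described above.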

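2. Dataset and Reproduced Accuracy

Dataset

The paper uses ImageNet-1k 2012 (ILSVRC2012): 1,000 classes, with 1,281,167 training images and 50,000 validation images. The dataset is 144 GB.

This project uses Light_ILSVRC2012, the officially recommended lighter version with compressed images, which is 65 GB. It is available on AI Studio as Light_ILSVRC2012_part_0.tar and Light_ILSVRC2012_part_1.tar.

Reproduced Accuracy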

| Model | Target Acc@1 | Achieved Acc@1 | Image Size | batch_size | Crop_pct | epoch | #Params |
|---|---|---|---|---|---|---|---|
| pit_ti | 73.0 | 73.01 | 224 | 256*4GPUs | 0.9 | 300 (+10 COOLDOWN) | 4.8M |

[Note] The achieved accuracy in the table above was measured on the original ILSVRC2012 validation set. On the Light_ILSVRC2012 validation set, this project reaches a Validation Acc@1 of 73.17.
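
Acc@1 is top-1 accuracy: the percentage of validation images whose highest-scoring class equals the ground-truth label. A minimal plain-Python sketch follows; the function name and the toy scores are illustrative and not taken from this repo.

```python
def top1_accuracy(scores, labels):
    """Percentage of samples whose arg-max class matches the label."""
    if len(scores) != len(labels):
        raise ValueError("scores and labels must have the same length")
    correct = 0
    for row, label in zip(scores, labels):
        # Index of the largest score in this row.
        predicted = max(range(len(row)), key=row.__getitem__)
        correct += int(predicted == label)
    return 100.0 * correct / len(labels)


# Toy data: 4 samples, 3 classes; the last sample is misclassified.
toy_scores = [[0.1, 2.0, 0.3], [1.5, 0.2, 0.1], [0.0, 0.4, 0.9], [0.7, 0.6, 0.1]]
toy_labels = [1, 0, 2, 1]
print(top1_accuracy(toy_scores, toy_labels))
```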

The best model weights and the training logs are stored in the output folder.

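Log Files

This project was run as an AI Studio script task and was interrupted 4 times, so there are 5 training log files. For easier review, the author manually renamed the logs to the format log_<start epoch>-<end epoch>.txt. Specifically:

  • output/log_1-76.txt: epoch 1 to epoch 76. This version of the code saved the model weights every 10 epochs and validated every 2 epochs; whenever the validation accuracy exceeded the previous best, the weights were saved as Best_PiT.pdparams. The run was interrupted after epoch 76 had finished training but before it was validated, so the next run could only resume from epoch 74, the epoch with the highest validation accuracy.

  • output/log_75-142.txt: epoch 75 to epoch 142. Starting with this version, the code also saves the weights as PiT-Latest.pdparams after every training epoch, so training can resume no matter which epoch is interrupted.

  • output/log_143-225.txt: epoch 143 to epoch 225.

  • output/log_226-303.txt: epoch 226 to epoch 303.

  • output/log_304-310.txt: epoch 304 to epoch 310.

  • output/log_eval.txt: validation log of the best trained model (epoch 308) on the original ILSVRC2012 validation set.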
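
The two checkpoint policies described in the log notes can be compared with a small simulation. This is only a sketch of the bookkeeping, not the repo's training loop: the function name, its parameters, and the periodic file name PiT-Epoch-N.pdparams are hypothetical, while Best_PiT.pdparams and PiT-Latest.pdparams are the names mentioned above.

```python
def simulate_checkpoints(num_epochs, val_acc, interrupted_after_training=None,
                         save_latest=True, save_every=10, val_every=2):
    """Return {checkpoint file name: epoch stored in it} at the moment the run stops.

    val_acc is a callable mapping an epoch number to a validation accuracy.
    """
    saved = {}
    best_acc = float("-inf")
    for epoch in range(1, num_epochs + 1):
        # ... one epoch of training would run here ...
        if save_latest:
            saved["PiT-Latest.pdparams"] = epoch
        if epoch == interrupted_after_training:
            break  # interrupted after training, before validation
        if epoch % val_every == 0:
            acc = val_acc(epoch)
            if acc > best_acc:
                best_acc = acc
                saved["Best_PiT.pdparams"] = epoch
        if epoch % save_every == 0:
            saved[f"PiT-Epoch-{epoch}.pdparams"] = epoch  # hypothetical name
    return saved


def rising(epoch):
    """Toy accuracy curve that improves every epoch."""
    return float(epoch)


old = simulate_checkpoints(310, rising, interrupted_after_training=76, save_latest=False)
new = simulate_checkpoints(310, rising, interrupted_after_training=76, save_latest=True)
print("first policy resumes from epoch", max(old.values()))
print("second policy resumes from epoch", max(new.values()))
```

With a steadily improving accuracy curve, the first policy leaves epoch 74 as the newest usable checkpoint, which matches the log history, while the second keeps epoch 76.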

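3. Environment Setup

Recommended environment (the author's own setup):

  • Hardware: Tesla V100 * 4 (sincere thanks to the Baidu PaddlePaddle platform for the high-performance compute)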

  • PaddlePaddle==2.2.1

  • Python==3.7
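
Before running the scripts, it can help to confirm that the interpreter is recent enough and that Paddle is importable. The helper below is a hypothetical convenience and not part of this repo. It uses only the standard library and does not verify the exact PaddlePaddle version.

```python
import importlib.util
import sys


def check_environment(min_python=(3, 7)):
    """Collect basic facts about the local Python environment."""
    return {
        "python": ".".join(str(part) for part in sys.version_info[:3]),
        "python_ok": sys.version_info[:2] >= min_python,
        # True if a top-level module named "paddle" can be found.
        "paddle_installed": importlib.util.find_spec("paddle") is not None,
    }


print(check_environment())
```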

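4. Quick Start

This project is deployed on AI Studio as a script task. You can fork it and run sh run.sh directly; the dataset-processing scripts are already set up. Link: paddle_pit

Alternatively, clone this repo and run it locally:

Step 1: Clone the project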

git clone https://github.com/hatimwen/paddle_pit.git
cd paddle_pit

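Step 2: Adjust the parameters

Edit the scripts under the scripts directory to match your setup (for example gpu, the dataset path data_path, and batch_size).

Step 3: Evaluate the model

For multiple GPUs, run: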

sh scripts/run_eval_multi.sh

For a single GPU, run:

sh scripts/run_eval.sh

Step 4: Train the model

For multiple GPUs, run:

sh scripts/run_train_multi.sh

For a single GPU, run:

sh scripts/run_train.sh

Step 5: Run a prediction

python predict.py \
-pretrained='output/Best_PiT' \
-img_path='images/ILSVRC2012_val_00004506.JPEG'

Test image (class: Tibetan mastiff, id: 244)

The output is:

class_id: 244, prob: 9.12291145324707

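Looking up the ImageNet class ids (the mapping from ImageNet indices to class names) shows that 244 is the Tibetan mastiff, so the prediction is correct.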
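
The value printed as prob is greater than 1, so it is presumably an unnormalized score (logit) rather than a probability. That reading is an assumption about predict.py, not something this README states. Under that assumption, a softmax over all class scores yields probabilities, and the arg-max gives the class id. A sketch with made-up toy scores:

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of scores."""
    peak = max(logits)  # subtract the max to avoid overflow in exp
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


toy_logits = [1.0, 9.12, 3.5, 0.2]  # made-up scores; index 1 is the largest
probs = softmax(toy_logits)
class_id = max(range(len(probs)), key=probs.__getitem__)
print(class_id, round(probs[class_id], 4))
```

Softmax preserves the ordering of the scores, so the predicted class id is the same whether it is taken from the logits or from the probabilities.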

5. Code Structure

|-- paddle_pit
    |-- output              # logs and model files
    |-- configs             # configuration files
        |-- pit_ti.yaml
    |-- datasets
        |-- ImageNet1K      # dataset path
    |-- scripts             # run scripts
        |-- run_train.sh
        |-- run_train_multi.sh
        |-- run_eval.sh
        |-- run_eval_multi.sh
    |-- augment.py          # data augmentation
    |-- config.py           # base (lowest-level) configuration
    |-- datasets.py         # dataset and dataloader
    |-- droppath.py         # DropPath definition
    |-- losses.py           # loss definitions
    |-- main_multi_gpu.py   # multi-GPU training and evaluation
    |-- main_single_gpu.py  # single-GPU training and evaluation
    |-- mixup.py            # mixup definition
    |-- model_ema.py        # EMA definition
    |-- pit.py              # PiT model definition
    |-- random_erasing.py   # random erasing definition
    |-- regnet.py           # teacher model definition (not verified in this project; kept for reference)
    |-- transforms.py       # RandomHorizontalFlip definition
    |-- utils.py            # CosineLRScheduler and AverageMeter definitions
    |-- README.md
    |-- requirements.txt
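
utils.py provides a CosineLRScheduler, and the results table lists 300 epochs plus 10 cooldown epochs. A common form of such a schedule decays the learning rate along a half cosine and then holds it at the minimum during cooldown. The sketch below shows that shape in plain Python. The function name, base_lr, min_lr, and the absence of warmup are illustrative assumptions, not values read from this repo's configs.

```python
import math


def cosine_lr(epoch, base_lr=1e-3, min_lr=1e-5, epochs=300, cooldown=10):
    """Learning rate at a 0-based epoch index: cosine decay, then cooldown."""
    if not 0 <= epoch < epochs + cooldown:
        raise ValueError("epoch out of range")
    if epoch >= epochs:
        return min_lr  # cooldown: hold the minimum learning rate
    cosine = 0.5 * (1.0 + math.cos(math.pi * epoch / epochs))
    return min_lr + (base_lr - min_lr) * cosine


for e in (0, 150, 299, 300, 309):
    print(e, cosine_lr(e))
```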

6. References and Citation

@InProceedings{Heo_2021_ICCV,
    author    = {Heo, Byeongho and Yun, Sangdoo and Han, Dongyoon and Chun, Sanghyuk and Choe, Junsuk and Oh, Seong Joon},
    title     = {Rethinking Spatial Dimensions of Vision Transformers},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month     = {October},
    year      = {2021}
}

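Finally, many thanks to Baidu for organizing the PaddlePaddle Paper Reproduction Challenge (5th edition), which gave the author a deeper understanding of Paddle. Many thanks also to the team led by Zhu Yu for PaddleViT, implemented in Paddle: most of the code in this project is copied from it, and this project only adds closer alignment with the official training procedure and a complete training run. Even so, the author learned a great deal from the process! ♥️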

Contact

Owner
Hongtao Wen