PyTorch common framework to accelerate network implementation, training and validation

Overview

pytorch-framework


This framework is inspired by work from MMLab: like those projects, it modularizes the data, networks, losses, metrics, etc., which makes the framework flexible and easy to modify and extend.

How to use

# install necessary libs
pip install -r requirements.txt

The framework is organized into six subfolders:

  • networks: all networks should be implemented in the networks folder, in files named {NAME}_network.py.
  • datasets: all datasets should be implemented in the datasets folder, in files named {NAME}_dataset.py.
  • losses: all losses should be implemented in the losses folder, in files named {NAME}_loss.py.
  • metrics: all metrics should be implemented in the metrics folder, in files named {NAME}_metric.py.
  • models: all models should be implemented in the models folder, in files named {NAME}_model.py.
  • utils: all utility functions should be implemented in the utils folder, in files named {NAME}_util.py.
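Fixed filename suffixes make it easy to discover modules automatically. A common pattern in MMLab-style code bases (assumed here, not confirmed by this README) is to scan each folder and import every matching file, so that each file's registration decorator runs. The helpers `find_modules` and `import_all` below are hypothetical and only sketch that idea:

```python
import importlib
import os


def find_modules(folder, suffix):
    """List module names in `folder` whose filenames end with `suffix`.

    Hypothetical helper, e.g. find_modules("networks", "_network.py").
    """
    return sorted(
        os.path.splitext(filename)[0]
        for filename in os.listdir(folder)
        if filename.endswith(suffix)
    )


def import_all(package, suffix):
    """Import every matching module of `package` (run from the repository root).

    Importing a module executes its @..._REGISTRY.register() decorators,
    which is what makes each class available to the registry by name.
    """
    return [
        importlib.import_module(f"{package}.{name}")
        for name in find_modules(package, suffix)
    ]
```

Under this assumption, `import_all("networks", "_network.py")` would pick up every networks/{NAME}_network.py file, while a file that does not follow the naming convention would silently be skipped.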

The training and validation procedures are defined in .yaml option files, which are passed to train.py and test.py with the --opt flag:

# training 
CUDA_VISIBLE_DEVICES=gpu_ids python train.py --opt options/train.yaml

# validation/test
CUDA_VISIBLE_DEVICES=gpu_ids python test.py --opt options/test.yaml
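The commands select GPUs through CUDA_VISIBLE_DEVICES (for example, `CUDA_VISIBLE_DEVICES=0,1`), while the example configs below set `num_gpu: auto`. How the framework resolves `auto` is not documented here; the hypothetical helper below assumes it means "every GPU exposed through CUDA_VISIBLE_DEVICES":

```python
import os


def resolve_num_gpu(num_gpu, environ=None):
    """Hypothetical helper: turn the `num_gpu` option into a device count.

    Assumption: `auto` means "every GPU exposed through CUDA_VISIBLE_DEVICES".
    Returns None when the variable is unset, i.e. when the count has to come
    from the driver instead (e.g. torch.cuda.device_count()).
    """
    environ = os.environ if environ is None else environ
    if num_gpu != "auto":
        return int(num_gpu)
    if "CUDA_VISIBLE_DEVICES" not in environ:
        return None
    # The variable holds a comma-separated list of device ids, e.g. "0,2".
    ids = [d for d in environ["CUDA_VISIBLE_DEVICES"].split(",") if d.strip()]
    return len(ids)


print(resolve_num_gpu("auto", {"CUDA_VISIBLE_DEVICES": "0,1"}))  # 2
print(resolve_num_gpu(4))  # 4
```

In PyTorch, `torch.cuda.device_count()` already counts only the devices exposed by CUDA_VISIBLE_DEVICES, so a real implementation may simply call it.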

In the training .yaml file, you can define everything related to training, such as the experiment name, model, datasets, networks, losses, optimizers, metrics and other hyperparameters. Here is an example that trains VGG16 for image classification:

# general setting
name: vgg_train
backend: dp # DataParallel
type: ClassifierModel
num_gpu: auto

# path to resume network
path:
  resume_state: ~

# datasets
datasets:
  train_dataset:
    name: TrainDataset
    type: ImageNet
    data_root: ../data/train_data
  val_dataset:
    name: ValDataset
    type: ImageNet
    data_root: ../data/val_data
  # setting for train dataset
  batch_size: 8

# network setting
networks:
  classifier:
    type: VGG16
    num_classes: 1000

# training setting
train:
  total_iter: 10000
  optims:
    classifier:
      type: Adam
      lr: 1.0e-4
  schedulers:
    classifier:
      type: none
  losses:
    ce_loss:
      type: CrossEntropyLoss

# validation setting
val:
  val_freq: 10000

# log setting
logger:
  print_freq: 100
  save_checkpoint_freq: 10000
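The training config is iteration-based: `total_iter` sets the loop length, and `print_freq`, `val_freq` and `save_checkpoint_freq` set how often logging, validation and checkpointing happen. A minimal sketch of such a loop, assuming each `*_freq` option means "every N iterations" with iterations counted from 1 (the function `schedule_events` is hypothetical):

```python
def schedule_events(total_iter, print_freq, val_freq, save_checkpoint_freq):
    """Count the logging, validation and checkpoint events of an iteration loop.

    Assumes each *_freq option means "every N iterations", counting from 1.
    """
    counts = {"log": 0, "val": 0, "save": 0}
    for current_iter in range(1, total_iter + 1):
        # ... one optimization step would run here ...
        if current_iter % print_freq == 0:
            counts["log"] += 1
        if current_iter % val_freq == 0:
            counts["val"] += 1
        if current_iter % save_checkpoint_freq == 0:
            counts["save"] += 1
    return counts


print(schedule_events(10000, 100, 10000, 10000))
# {'log': 100, 'val': 1, 'save': 1}
```

Under this assumption, the example config validates and saves a checkpoint only once, at the final iteration; a smaller `val_freq` would let you monitor progress during training.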

In the validation .yaml file, you can define everything related to validation, such as the model, datasets and metrics. Here is an example:

# general setting
name: test
backend: dp # DataParallel
type: ClassifierModel
num_gpu: auto
manual_seed: 1234

# path
path:
  resume_state: experiments/train/models/final.pth
  resume: false

# datasets
datasets:
  val_dataset:
    name: ValDataset
    type: ImageNet
    data_root: ../data/test_data

# network setting
networks:
  classifier:
    type: VGG16
    num_classes: 1000

# validation setting
val:
  metrics:
    accuracy:
      type: calculate_accuracy
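The metric is selected by name (`type: calculate_accuracy`). Its actual signature in the metrics folder is not shown in this README; the sketch below assumes it receives per-class scores and integer labels and returns top-1 accuracy:

```python
def calculate_accuracy(scores, labels):
    """Top-1 accuracy: fraction of samples whose highest score is the true label.

    Hypothetical signature: `scores` is a list of per-class score lists,
    `labels` holds the true class indices.
    """
    if len(scores) != len(labels):
        raise ValueError("scores and labels must have the same length")
    correct = sum(
        max(range(len(row)), key=row.__getitem__) == label
        for row, label in zip(scores, labels)
    )
    return correct / len(labels)


print(calculate_accuracy([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]], [1, 0, 0]))
# 0.6666666666666666
```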

Framework Details

The core of the framework is BaseModel in base_model.py. BaseModel controls the whole training/validation procedure, from initialization through the training/validation iterations to saving the results.

  • Initialization: during initialization, the model reads the configuration in the .yaml file and constructs the corresponding networks, datasets, losses, optimizers, metrics, etc.
  • Training/Validation: for the training loop, refer to train.py; for the validation loop, refer to test.py.
  • Results saving: during training, the model automatically saves the state_dict of the networks and optimizers, along with other hyperparameters.
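The exact checkpoint layout is not described beyond the state_dict of the networks and optimizers. A minimal sketch, assuming one file that holds a dictionary keyed by the names used in the .yaml file (e.g. `classifier`) plus an `iter` counter; this layout and both helper names are hypothetical:

```python
import torch


def save_checkpoint(path, networks, optimizers, current_iter):
    """Write one file holding every network/optimizer state_dict and the iteration."""
    state = {
        "iter": current_iter,
        "networks": {name: net.state_dict() for name, net in networks.items()},
        "optimizers": {name: opt.state_dict() for name, opt in optimizers.items()},
    }
    torch.save(state, path)


def resume_checkpoint(path, networks, optimizers):
    """Restore the states written by save_checkpoint; return the saved iteration."""
    state = torch.load(path, map_location="cpu")
    for name, net in networks.items():
        net.load_state_dict(state["networks"][name])
    for name, opt in optimizers.items():
        opt.load_state_dict(state["optimizers"][name])
    return state["iter"]
```

With the example configs, both dictionaries would be keyed by `classifier`, and `resume_state` would point at a file like this.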

The framework is configured through the registries in utils/registry.py. Each registry holds an object map of key-value pairs: the key is the name of the object and the value is its class. There are four registries in total, for networks, datasets, losses and metrics. Here is an example that registers a new network:

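import torch
import torch.nn as nn

from utils.registry import NETWORK_REGISTRY


# Registering the class makes it selectable by name from the .yaml file,
# e.g. `type: MyNet` under `networks:`.
@NETWORK_REGISTRY.register()
class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        ...  # define the layers here

    def forward(self, x):
        ...  # define the forward pass here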
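To illustrate the key-value idea, here is a minimal registry sketch. It is a hypothetical re-implementation, not the code in registry.py; in particular, using the class name as the key and passing the remaining config entries as constructor arguments are assumptions. A plain class stands in for nn.Module to keep the sketch dependency-free:

```python
class Registry:
    """Minimal name -> class map; a hypothetical stand-in for utils/registry.py."""

    def __init__(self, name):
        self._name = name
        self._obj_map = {}  # key: object name, value: the class itself

    def register(self, obj=None):
        def _register(cls):
            key = cls.__name__
            if key in self._obj_map:
                raise KeyError(f"{key!r} is already registered in {self._name}")
            self._obj_map[key] = cls
            return cls

        # Works both as @REGISTRY.register() and as REGISTRY.register(SomeClass)
        if obj is None:
            return _register
        return _register(obj)

    def get(self, name):
        if name not in self._obj_map:
            raise KeyError(f"no object named {name!r} in the {self._name} registry")
        return self._obj_map[name]


NETWORK_REGISTRY = Registry("network")


@NETWORK_REGISTRY.register()
class MyNet:  # plain class instead of nn.Module, only for this sketch
    def __init__(self, num_classes):
        self.num_classes = num_classes


# Build a network from a config entry such as `type: MyNet` plus its keyword arguments
cfg = {"type": "MyNet", "num_classes": 1000}
kwargs = {k: v for k, v in cfg.items() if k != "type"}
net = NETWORK_REGISTRY.get(cfg["type"])(**kwargs)
print(type(net).__name__, net.num_classes)  # MyNet 1000
```

Keying by name is what lets the .yaml file refer to a class with a plain string, and rejecting duplicate names avoids two files silently overwriting each other's entries.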
Owner
Dongliang Cao
[SDM 2022] Towards Similarity-Aware Time-Series Classification

SimTSC This is the PyTorch implementation of SDM2022 paper Towards Similarity-Aware Time-Series Classification. We propose Similarity-Aware Time-Serie

Daochen Zha 49 Dec 27, 2022
The official repo of the CVPR 2021 paper Group Collaborative Learning for Co-Salient Object Detection .

GCoNet The official repo of the CVPR 2021 paper Group Collaborative Learning for Co-Salient Object Detection . Trained model Download final_gconet.pth

Qi Fan 46 Nov 17, 2022
tensorflow implementation of 'YOLO : Real-Time Object Detection'

YOLO_tensorflow (Version 0.3, Last updated :2017.02.21) 1.Introduction This is tensorflow implementation of the YOLO:Real-Time Object Detection It can

Jinyoung Choi 1.7k Nov 21, 2022
NEATEST: Evolving Neural Networks Through Augmenting Topologies with Evolution Strategy Training

NEATEST: Evolving Neural Networks Through Augmenting Topologies with Evolution Strategy Training

Göktuğ Karakaşlı 16 Dec 05, 2022
Diffusion Probabilistic Models for 3D Point Cloud Generation (CVPR 2021)

Diffusion Probabilistic Models for 3D Point Cloud Generation [Paper] [Code] The official code repository for our CVPR 2021 paper "Diffusion Probabilis

Shitong Luo 323 Jan 05, 2023
CLDF dataset derived from Robbeets et al.'s "Triangulation Supports Agricultural Spread" from 2021

CLDF dataset derived from Robbeets et al.'s "Triangulation Supports Agricultural Spread" from 2021 How to cite If you use these data please cite the o

Digital Linguistics 2 Dec 20, 2021
PICARD - Parsing Incrementally for Constrained Auto-Regressive Decoding from Language Models

This is the official implementation of the following paper: Torsten Scholak, Nathan Schucher, Dzmitry Bahdanau. PICARD - Parsing Incrementally for Con

ElementAI 217 Jan 01, 2023
Coursera - Quiz & Assignment of Coursera

Coursera Assignments This repository is aimed to help Coursera learners who have difficulties in their learning process. The quiz and programming home

浅梦 828 Jan 04, 2023
Easy genetic ancestry predictions in Python

ezancestry Easily visualize your direct-to-consumer genetics next to 2500+ samples from the 1000 genomes project. Evaluate the performance of a custom

Kevin Arvai 38 Jan 02, 2023
A-ESRGAN aims to provide better super-resolution images by using multi-scale attention U-net discriminators.

A-ESRGAN: Training Real-World Blind Super-Resolution with Attention-based U-net Discriminators The authors are hidden for the purpose of double blind

77 Dec 16, 2022
Welcome to The Eigensolver Quantum School, a quantum computing crash course designed by students for students.

TEQS Welcome to The Eigensolver Quantum School, a crash course designed by students for students. The aim of this program is to take someone who has n

The Eigensolvers 53 May 18, 2022
Open source repository for the code accompanying the paper 'Non-Rigid Neural Radiance Fields Reconstruction and Novel View Synthesis of a Deforming Scene from Monocular Video'.

Non-Rigid Neural Radiance Fields This is the official repository for the project "Non-Rigid Neural Radiance Fields: Reconstruction and Novel View Synt

Facebook Research 296 Dec 29, 2022
High accurate tool for automatic faces detection with landmarks

faces_detanator High accurate tool for automatic faces detection with landmarks. The library is based on public detectors with high accuracy (TinaFace

Ihar 7 May 10, 2022
Enabling Lightweight Fine-tuning for Pre-trained Language Model Compression based on Matrix Product Operators

Enabling Lightweight Fine-tuning for Pre-trained Language Model Compression based on Matrix Product Operators This is our Pytorch implementation for t

RUCAIBox 12 Jul 22, 2022
Source code and dataset of the paper "Contrastive Adaptive Propagation Graph Neural Networks forEfficient Graph Learning"

CAPGNN Source code and dataset of the paper "Contrastive Adaptive Propagation Graph Neural Networks forEfficient Graph Learning" Paper URL: https://ar

1 Mar 12, 2022
An open software package to develop BCI based brain and cognitive computing technology for recognizing user's intention using deep learning

An open software package to develop BCI based brain and cognitive computing technology for recognizing user's intention using deep learning

deepbci 272 Jan 08, 2023
[AAAI 2021] EMLight: Lighting Estimation via Spherical Distribution Approximation and [ICCV 2021] Sparse Needlets for Lighting Estimation with Spherical Transport Loss

EMLight: Lighting Estimation via Spherical Distribution Approximation (AAAI 2021) Update 12/2021: We release our Virtual Object Relighting (VOR) Datas

Fangneng Zhan 144 Jan 06, 2023
Two-Stream Adaptive Graph Convolutional Networks for Skeleton-Based Action Recognition in CVPR19

2s-AGCN Two-Stream Adaptive Graph Convolutional Networks for Skeleton-Based Action Recognition in CVPR19 Note PyTorch version should be 0.3! For PyTor

LShi 547 Dec 26, 2022
Here is the implementation of our paper S2VC: A Framework for Any-to-Any Voice Conversion with Self-Supervised Pretrained Representations.

S2VC Here is the implementation of our paper S2VC: A Framework for Any-to-Any Voice Conversion with Self-Supervised Pretrained Representations. In thi

81 Dec 15, 2022
Code of paper: "DropAttack: A Masked Weight Adversarial Training Method to Improve Generalization of Neural Networks"

DropAttack: A Masked Weight Adversarial Training Method to Improve Generalization of Neural Networks Abstract: Adversarial training has been proven to

倪仕文 (Shiwen Ni) 58 Nov 10, 2022