PyTorch common framework to accelerate network implementation, training and validation

Overview

pytorch-framework

PyTorch common framework to accelerate network implementation, training and validation.

This framework is inspired by works from MMLab, which modularize the data, network, loss, metric, etc. to make the framework to be flexible, easy to modify and to extend.

How to use

# install necessary libs
pip install -r requirements.txt

The framework contains six different subfolders:

  • networks: all networks should be implemented under the networks folder with {NAME}_network.py filename.
  • datasets: all datasets should be implemented under the datasets folder with {NAME}_dataset.py filename.
  • losses: all losses should be implemented under the losses folder with {NAME}_loss.py filename.
  • metrics: all metrics should be implemented under the metrics folder with {NAME}_metric.py filename.
  • models: all models should be implemented under the models folder with {NAME}_model.py filename.
  • utils: all util functions should be implemented under the utils folder with {NAME}_util.py filename.

The training and validation procedure can be defined in the specified .yaml file.

# training 
CUDA_VISIBLE_DEVICES=gpu_ids python train.py --opt options/train.yaml

# validation/test
CUDA_VISIBLE_DEVICES=gpu_ids python test.py --opt options/test.yaml

In the .yaml file for training, you can define all the things related to training such as the experiment name, model, dataset, network, loss, optimizer, metrics and other hyper-parameters. Here is an example to train VGG16 for image classification:

# general setting
name: vgg_train
backend: dp # DataParallel
type: ClassifierModel
num_gpu: auto

# path to resume network
path:
  resume_state: ~

# datasets
datasets:
  train_dataset:
    name: TrainDataset
    type: ImageNet
    data_root: ../data/train_data
  val_dataset:
    name: ValDataset
    type: ImageNet
    data_root: ../data/val_data
  # setting for train dataset
  batch_size: 8

# network setting
networks:
  classifier:
    type: VGG16
    num_classes: 1000

# training setting
train:
  total_iter: 10000
  optims:
    classifier:
      type: Adam
      lr: 1.0e-4
  schedulers:
    classifier:
      type: none
  losses:
    ce_loss:
      type: CrossEntropyLoss

# validation setting
val:
  val_freq: 10000

# log setting
logger:
  print_freq: 100
  save_checkpoint_freq: 10000

In the .yaml file for validation, you can define all the things related to validation such as: model, dataset, metrics. Here is an example:

# general setting
name: test
backend: dp # DataParallel
type: ClassifierModel
num_gpu: auto
manual_seed: 1234

# path
path:
  resume_state: experiments/train/models/final.pth
  resume: false

# datasets
datasets:
  val_dataset:
    name: ValDataset
    type: ImageNet
    data_root: ../data/test_data

# network setting
networks:
  classifier:
    type: VGG
    num_classes: 1000

# validation setting
val:
  metrics:
    accuracy:
      type: calculate_accuracy

Framework Details

The core of the framework is the BaseModel in the base_model.py. The BaseModel controls the whole training/validation procedure from initialization over training/validation iteration to results saving.

  • Initialization: In the model initialization, it will read the configuration in the .yaml file and construct the corresponding networks, datasets, losses, optimizers, metrics, etc.
  • Training/Validation: In the training/validation procedure, you can refer the training process in the train.py and the validation process in the test.py.
  • Results saving: The model will automatically save the state_dict for networks, optimizers and other hyperparameters during the training.

The configuration of the framework is down by Register in the registry.py. The Register has a object map (key-value pair). The key is the name of the object, the value is the class of the object. There are total 4 different registers for networks, datasets, losses and metrics. Here is an example to register a new network:

import torch
import torch.nn as nn

from utils.registry import NETWORK_REGISTRY

@NETWORK_REGISTRY.register()
class MyNet(nn.Module):
  ...
Owner
Dongliang Cao
Dongliang Cao
Implementation for On Provable Benefits of Depth in Training Graph Convolutional Networks

Implementation for On Provable Benefits of Depth in Training Graph Convolutional Networks Setup This implementation is based on PyTorch = 1.0.0. Smal

Weilin Cong 8 Oct 28, 2022
Multiview Neural Surface Reconstruction by Disentangling Geometry and Appearance

Multiview Neural Surface Reconstruction by Disentangling Geometry and Appearance Project Page | Paper | Data This repository contains an implementatio

Lior Yariv 521 Dec 30, 2022
MLJetReconstruction - using machine learning to reconstruct jets for CMS

MLJetReconstruction - using machine learning to reconstruct jets for CMS The C++ data extraction code used here was based heavily on that foundv here.

ALPhA Davidson 0 Nov 17, 2021
K-Means Clustering and Hierarchical Clustering Unsupervised Learning Solution in Python3.

Unsupervised Learning - K-Means Clustering and Hierarchical Clustering - The Heritage Foundation's Economic Freedom Index Analysis 2019 - By David Sal

David Salako 1 Jan 12, 2022
A CROSS-MODAL FUSION NETWORK BASED ON SELF-ATTENTION AND RESIDUAL STRUCTURE FOR MULTIMODAL EMOTION RECOGNITION

CFN-SR A CROSS-MODAL FUSION NETWORK BASED ON SELF-ATTENTION AND RESIDUAL STRUCTURE FOR MULTIMODAL EMOTION RECOGNITION The audio-video based multimodal

skeleton 15 Sep 26, 2022
This repo holds codes of the ICCV21 paper: Visual Alignment Constraint for Continuous Sign Language Recognition.

VAC_CSLR This repo holds codes of the paper: Visual Alignment Constraint for Continuous Sign Language Recognition.(ICCV 2021) [paper] Prerequisites Th

Yuecong Min 64 Dec 19, 2022
Revisiting Contrastive Methods for Unsupervised Learning of Visual Representations. [2021]

Revisiting Contrastive Methods for Unsupervised Learning of Visual Representations This repo contains the Pytorch implementation of our paper: Revisit

Wouter Van Gansbeke 80 Nov 20, 2022
Large dataset storage format for Pytorch

H5Record Large dataset ( 100G, = 1T) storage format for Pytorch (wip) Support python 3 pip install h5record Why? Writing large dataset is still a

theblackcat102 43 Oct 22, 2022
Official implementation of "Synthetic Temporal Anomaly Guided End-to-End Video Anomaly Detection" (ICCV Workshops 2021: RSL-CV).

Official PyTorch implementation of "Synthetic Temporal Anomaly Guided End-to-End Video Anomaly Detection" This is the implementation of the paper "Syn

Marcella Astrid 11 Oct 07, 2022
Multi-Joint dynamics with Contact. A general purpose physics simulator.

MuJoCo Physics MuJoCo stands for Multi-Joint dynamics with Contact. It is a general purpose physics engine that aims to facilitate research and develo

DeepMind 5.2k Jan 02, 2023
Evaluating Privacy-Preserving Machine Learning in Critical Infrastructures: A Case Study on Time-Series Classification

PPML-TSA This repository provides all code necessary to reproduce the results reported in our paper Evaluating Privacy-Preserving Machine Learning in

Dominik 1 Mar 08, 2022
Example of a Quantum LSTM

Example of a Quantum LSTM

Riccardo Di Sipio 36 Oct 31, 2022
Lightweight tool to perform MITM attack on local network

ARPSpy - A lightweight tool to perform MITM attack Using many library to perform ARP Spoof and auto-sniffing HTTP packet containing credential. (Never

MinhItachi 8 Aug 28, 2022
Novel and high-performance medical image classification pipelines are heavily utilizing ensemble learning strategies

An Analysis on Ensemble Learning optimized Medical Image Classification with Deep Convolutional Neural Networks Novel and high-performance medical ima

14 Dec 18, 2022
Adaptive Graph Convolution for Point Cloud Analysis

Adaptive Graph Convolution for Point Cloud Analysis This repository contains the implementation of AdaptConv for point cloud analysis. Adaptive Graph

64 Dec 21, 2022
GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond

GCNet for Object Detection By Yue Cao, Jiarui Xu, Stephen Lin, Fangyun Wei, Han Hu. This repo is a official implementation of "GCNet: Non-local Networ

Jerry Jiarui XU 1.1k Dec 29, 2022
An Unpaired Sketch-to-Photo Translation Model

Unpaired-Sketch-to-Photo-Translation We have released our code at https://github.com/rt219/Unsupervised-Sketch-to-Photo-Synthesis This project is the

38 Oct 28, 2022
Multimodal Descriptions of Social Concepts: Automatic Modeling and Detection of (Highly Abstract) Social Concepts evoked by Art Images

MUSCO - Multimodal Descriptions of Social Concepts Automatic Modeling of (Highly Abstract) Social Concepts evoked by Art Images This project aims to i

0 Aug 22, 2021
Official PyTorch implementation of the paper "TEMOS: Generating diverse human motions from textual descriptions"

TEMOS: TExt to MOtionS Generating diverse human motions from textual descriptions Description Official PyTorch implementation of the paper "TEMOS: Gen

Mathis Petrovich 187 Dec 27, 2022
Neural Scene Flow Fields for Space-Time View Synthesis of Dynamic Scenes

Neural Scene Flow Fields PyTorch implementation of paper "Neural Scene Flow Fields for Space-Time View Synthesis of Dynamic Scenes", CVPR 2021 [Projec

Zhengqi Li 583 Dec 30, 2022