Fine-tune pretrained Convolutional Neural Networks with PyTorch

Overview

Fine-tune pretrained Convolutional Neural Networks with PyTorch.

PyPI CircleCI codecov.io

Features

  • Gives access to the most popular CNN architectures pretrained on ImageNet.
  • Automatically replaces classifier on top of the network, which allows you to train a network with a dataset that has a different number of classes.
  • Allows you to use images with any resolution (and not only the resolution that was used for training the original model on ImageNet).
  • Allows adding a Dropout layer or a custom pooling layer.

Supported architectures and models

From the torchvision package:

  • ResNet (resnet18, resnet34, resnet50, resnet101, resnet152)
  • ResNeXt (resnext50_32x4d, resnext101_32x8d)
  • DenseNet (densenet121, densenet169, densenet201, densenet161)
  • Inception v3 (inception_v3)
  • VGG (vgg11, vgg11_bn, vgg13, vgg13_bn, vgg16, vgg16_bn, vgg19, vgg19_bn)
  • SqueezeNet (squeezenet1_0, squeezenet1_1)
  • MobileNet V2 (mobilenet_v2)
  • ShuffleNet v2 (shufflenet_v2_x0_5, shufflenet_v2_x1_0)
  • AlexNet (alexnet)
  • GoogLeNet (googlenet)

From the Pretrained models for PyTorch package:

  • ResNeXt (resnext101_32x4d, resnext101_64x4d)
  • NASNet-A Large (nasnetalarge)
  • NASNet-A Mobile (nasnetamobile)
  • Inception-ResNet v2 (inceptionresnetv2)
  • Dual Path Networks (dpn68, dpn68b, dpn92, dpn98, dpn131, dpn107)
  • Inception v4 (inception_v4)
  • Xception (xception)
  • Squeeze-and-Excitation Networks (senet154, se_resnet50, se_resnet101, se_resnet152, se_resnext50_32x4d, se_resnext101_32x4d)
  • PNASNet-5-Large (pnasnet5large)
  • PolyNet (polynet)

Requirements

  • Python 3.5+
  • PyTorch 1.1+

Installation

pip install cnn_finetune

Major changes:

Version 0.4

  • Default value for pretrained argument in make_model is changed from False to True. Now call make_model('resnet18', num_classes=10) is equal to make_model('resnet18', num_classes=10, pretrained=True)

Example usage:

Make a model with ImageNet weights for 10 classes

from cnn_finetune import make_model

model = make_model('resnet18', num_classes=10, pretrained=True)

Make a model with Dropout

model = make_model('nasnetalarge', num_classes=10, pretrained=True, dropout_p=0.5)

Make a model with Global Max Pooling instead of Global Average Pooling

import torch.nn as nn

model = make_model('inceptionresnetv2', num_classes=10, pretrained=True, pool=nn.AdaptiveMaxPool2d(1))

Make a VGG16 model that takes images of size 256x256 pixels

VGG and AlexNet models use fully-connected layers, so you have to additionally pass the input size of images when constructing a new model. This information is needed to determine the input size of fully-connected layers.

model = make_model('vgg16', num_classes=10, pretrained=True, input_size=(256, 256))

Make a VGG16 model that takes images of size 256x256 pixels and uses a custom classifier

import torch.nn as nn

def make_classifier(in_features, num_classes):
    return nn.Sequential(
        nn.Linear(in_features, 4096),
        nn.ReLU(inplace=True),
        nn.Linear(4096, num_classes),
    )

model = make_model('vgg16', num_classes=10, pretrained=True, input_size=(256, 256), classifier_factory=make_classifier)

Show preprocessing that was used to train the original model on ImageNet

>> model = make_model('resnext101_64x4d', num_classes=10, pretrained=True)
>> print(model.original_model_info)
ModelInfo(input_space='RGB', input_size=[3, 224, 224], input_range=[0, 1], mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
>> print(model.original_model_info.mean)
[0.485, 0.456, 0.406]

CIFAR10 Example

See examples/cifar10.py file (requires PyTorch 1.1+).

Owner
Alex Parinov
Computer Vision Architect
Alex Parinov
The repo contains the code to train and evaluate a system which extracts relations and explanations from dialogue.

The repo contains the code to train and evaluate a system which extracts relations and explanations from dialogue. How do I cite D-REX? For now, cite

Alon Albalak 6 Mar 31, 2022
The best solution of the Weather Prediction track in the Yandex Shifts challenge

yandex-shifts-weather The repository contains information about my solution for the Weather Prediction track in the Yandex Shifts challenge https://re

Ivan Yu. Bondarenko 15 Dec 18, 2022
A simple, high level, easy-to-use open source Computer Vision library for Python.

ZoomVision : Slicing Aid Detection A simple, high level, easy-to-use open source Computer Vision library for Python. Installation Installing dependenc

Nurettin Sinanoğlu 2 Mar 04, 2022
This is the repository for the AAAI 21 paper [Contrastive and Generative Graph Convolutional Networks for Graph-based Semi-Supervised Learning].

CG3 This is the repository for the AAAI 21 paper [Contrastive and Generative Graph Convolutional Networks for Graph-based Semi-Supervised Learning]. R

12 Oct 28, 2022
Implementation for "Domain-Specific Bias Filtering for Single Labeled Domain Generalization"

DSBF Introduction This repository contains the implementation code for paper: Domain-Specific Bias Filtering for Single Labeled Domain Generalization

ScottYuan 7 Jan 05, 2023
Material del curso IIC2233 Programación Avanzada 📚

Contenidos Los contenidos se organizan según la semana del semestre en que nos encontremos, y según la semana que se destina para su estudio. Los cont

IIC2233 @ UC 72 Dec 23, 2022
This project implements "virtual speed" from heart rate monito

ANT+ Virtual Stride Based Speed and Distance Monitor Overview This project imple

2 May 20, 2022
Taming Transformers for High-Resolution Image Synthesis

Taming Transformers for High-Resolution Image Synthesis CVPR 2021 (Oral) Taming Transformers for High-Resolution Image Synthesis Patrick Esser*, Robin

CompVis Heidelberg 3.5k Jan 03, 2023
TorchCV: A PyTorch-Based Framework for Deep Learning in Computer Vision

TorchCV: A PyTorch-Based Framework for Deep Learning in Computer Vision @misc{you2019torchcv, author = {Ansheng You and Xiangtai Li and Zhen Zhu a

Donny You 2.2k Jan 06, 2023
Library for converting from RGB / GrayScale image to base64 and back.

Library for converting RGB / Grayscale numpy images from to base64 and back. Installation pip install -U image_to_base_64 Conversion RGB to base 64 b

Vladimir Iglovikov 16 Aug 28, 2022
Pytorch implementation of “Recursive Non-Autoregressive Graph-to-Graph Transformer for Dependency Parsing with Iterative Refinement”

Graph-to-Graph Transformers Self-attention models, such as Transformer, have been hugely successful in a wide range of natural language processing (NL

Idiap Research Institute 40 Aug 14, 2022
How Do Adam and Training Strategies Help BNNs Optimization? In ICML 2021.

AdamBNN This is the pytorch implementation of our paper "How Do Adam and Training Strategies Help BNNs Optimization?", published in ICML 2021. In this

Zechun Liu 47 Sep 20, 2022
A DCGAN to generate anime faces using custom mined dataset

Anime-Face-GAN-Keras A DCGAN to generate anime faces using custom dataset in Keras. Dataset The dataset is created by crawling anime database websites

Pavitrakumar P 190 Jan 03, 2023
Some experiments with tennis player aging curves using Hilbert space GPs in PyMC. Only experimental for now.

NOTE: This is still being developed! Setup notes This document uses Jeff Sackmann's tennis data. You can obtain it as follows: git clone https://githu

Martin Ingram 1 Jan 20, 2022
PyTorch implementation for 3D human pose estimation

Towards 3D Human Pose Estimation in the Wild: a Weakly-supervised Approach This repository is the PyTorch implementation for the network presented in:

Xingyi Zhou 579 Dec 22, 2022
Code and data for ImageCoDe, a contextual vison-and-language benchmark

ImageCoDe This repository contains code and data for ImageCoDe: Image Retrieval from Contextual Descriptions. Data All collected descriptions for the

McGill NLP 27 Dec 02, 2022
Using PyTorch Perform intent classification using three different models to see which one is better for this task

Using PyTorch Perform intent classification using three different models to see which one is better for this task

Yoel Graumann 1 Feb 14, 2022
Image Data Augmentation in Keras

Image data augmentation is a technique that can be used to artificially expand the size of a training dataset by creating modified versions of images in the dataset.

Grace Ugochi Nneji 3 Feb 15, 2022
Outlier Exposure with Confidence Control for Out-of-Distribution Detection

OOD-detection-using-OECC This repository contains the essential code for the paper Outlier Exposure with Confidence Control for Out-of-Distribution De

Nazim Shaikh 64 Nov 02, 2022
PyExplainer: A Local Rule-Based Model-Agnostic Technique (Explainable AI)

PyExplainer PyExplainer is a local rule-based model-agnostic technique for generating explanations (i.e., why a commit is predicted as defective) of J

AI Wizards for Software Management (AWSM) Research Group 14 Nov 13, 2022