Feature extraction made simple with torchextractor

Overview

torchextractor: PyTorch Intermediate Feature Extraction

Introduction

Too often, model definitions get remorselessly copy-pasted just because the forward function does not return what the user expects. You provide module names and torchextractor takes care of the extraction for you. It has never been easier to extract features, add an extra loss, or plug another head into a network. Let us know what amazing things you build with torchextractor!

Installation

pip install torchextractor  # stable
pip install git+https://github.com/antoinebrl/torchextractor.git  # latest

Requirements:

  • Python >= 3.6
  • torch >= 1.4.0

Usage

import torch
import torchvision
import torchextractor as tx

model = torchvision.models.resnet18(pretrained=True)
model = tx.Extractor(model, ["layer1", "layer2", "layer3", "layer4"])
dummy_input = torch.rand(7, 3, 224, 224)
model_output, features = model(dummy_input)
feature_shapes = {name: f.shape for name, f in features.items()}
print(feature_shapes)

# {
#   'layer1': torch.Size([7, 64, 56, 56]),
#   'layer2': torch.Size([7, 128, 28, 28]),
#   'layer3': torch.Size([7, 256, 14, 14]),
#   'layer4': torch.Size([7, 512, 7, 7]),
# }

See more examples

Read the documentation

FAQ

• How do I know the names of the modules?

You can print all module names like this:

tx.list_module_names(model)

# OR

for name, module in model.named_modules():
    print(name)

• Why do some operations not get listed?

Hooks can only be attached to operations that are defined as modules. Therefore, F.relu cannot be captured but nn.ReLU() can.
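
To see the difference in practice, here is a minimal sketch built on a hypothetical toy network (`TinyNet` is made up for illustration and is not part of torchextractor). The activation declared as a module appears in `named_modules()`, while the functional call inside `forward` leaves no trace and therefore cannot be hooked.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyNet(nn.Module):
    """Hypothetical toy model, used only for illustration."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.act = nn.ReLU()  # registered as a submodule, so it can be hooked
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x):
        x = self.act(self.fc1(x))
        return F.relu(self.fc2(x))  # functional call, invisible to hooks


model = TinyNet()
# Skip the root module, whose name is the empty string.
module_names = [name for name, _ in model.named_modules() if name]
print(module_names)  # ['fc1', 'act', 'fc2']
```

Only three names are listed even though the forward pass applies two activations; the second one exists only as a function call.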

• How can I avoid listing all relevant modules?

You can specify a custom filtering function to hook the relevant modules:

# Hook everything!
module_filter_fn = lambda module, name: True

# Capture all modules inside the first layer
module_filter_fn = lambda module, name: name.startswith("layer1")

# Focus on all convolutions
module_filter_fn = lambda module, name: isinstance(module, torch.nn.Conv2d)

model = tx.Extractor(model, module_filter_fn=module_filter_fn)
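
Before handing a filter to tx.Extractor, it can help to preview which modules it would select. The sketch below applies filters with the same `(module, name)` signature directly to `named_modules()`. The toy model and the helper `preview_filter` are hypothetical and not part of torchextractor, and the helper skips the unnamed root module, which is an assumption about what is usually wanted.

```python
import torch.nn as nn

# Hypothetical toy model standing in for a real network.
model = nn.Sequential()
model.add_module("layer1", nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()))
model.add_module("layer2", nn.Sequential(nn.Conv2d(8, 16, 3), nn.ReLU()))


def preview_filter(model, module_filter_fn):
    """Return the names of the modules for which the filter returns True."""
    return [
        name
        for name, module in model.named_modules()
        if name and module_filter_fn(module, name)
    ]


print(preview_filter(model, lambda module, name: name.startswith("layer1")))
# ['layer1', 'layer1.0', 'layer1.1']
print(preview_filter(model, lambda module, name: isinstance(module, nn.Conv2d)))
# ['layer1.0', 'layer2.0']
```

Note that a prefix filter such as `name.startswith("layer1")` matches the container and every module nested inside it, so the same activations may be captured at several levels.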

• Is it compatible with ONNX?

tx.Extractor is compatible with ONNX! This means you can also access intermediate feature maps after the export.

Pro-tip: name the output nodes by using output_names when calling torch.onnx.export.

• Is it compatible with TorchScript?

Not yet, but we are working on it. Support for compiling hooks registered on a module was only added in PyTorch v1.8.0.

• "One more thing!" 😉

By default we capture the latest output of the relevant modules, but you can specify your own custom operations.

For example, to accumulate features over 10 forward passes you can do the following:

import torch
import torchvision
import torchextractor as tx

model = torchvision.models.resnet18(pretrained=True)

def capture_fn(module, input, output, module_name, feature_maps):
    if module_name not in feature_maps:
        feature_maps[module_name] = []
    feature_maps[module_name].append(output)

extractor = tx.Extractor(model, ["layer3", "layer4"], capture_fn=capture_fn)

for _ in range(20):
    # Accumulate features over 10 forward passes
    for _ in range(10):
        x = torch.rand(7, 3, 224, 224)
        model(x)
    feature_maps = extractor.collect()

    # Do your stuff here

    # Discard collected elements
    extractor.clear_placeholder()
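
A capture function with this signature can also be tried out without torchextractor. The sketch below drives it with plain PyTorch forward hooks, which is an assumption about the mechanism and not torchextractor's actual implementation. The toy model is hypothetical, and this variant of `capture_fn` stores detached CPU copies so that accumulated features do not keep the autograd graph alive.

```python
from functools import partial

import torch
import torch.nn as nn


def capture_fn(module, input, output, module_name, feature_maps):
    # Keep a detached CPU copy so stored features do not hold the autograd graph.
    feature_maps.setdefault(module_name, []).append(output.detach().cpu())


# Hypothetical toy model; any nn.Module with named submodules would do.
model = nn.Sequential()
model.add_module("layer3", nn.Linear(4, 6))
model.add_module("layer4", nn.Linear(6, 2))

feature_maps = {}
# Bind the extra arguments so each hook has the (module, input, output) form.
handles = [
    module.register_forward_hook(
        partial(capture_fn, module_name=name, feature_maps=feature_maps)
    )
    for name, module in model.named_modules()
    if name in ("layer3", "layer4")
]

for _ in range(10):
    model(torch.rand(7, 4))

# One entry per forward pass, stacked along a new leading dimension.
stacked = torch.stack(feature_maps["layer4"])
print(stacked.shape)  # torch.Size([10, 7, 2])

for handle in handles:
    handle.remove()  # detach the hooks once done
```

Storing every output grows memory linearly with the number of passes, so clearing the accumulated features between rounds, as the example above does with `clear_placeholder()`, matters for long runs.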

Contributing

All feedback and contributions are welcome. Feel free to report an issue or to create a pull request!

If you want to get hands-on:

  1. (Fork and) clone the repo.
  2. Create a virtual environment: virtualenv -p python3 .venv && source .venv/bin/activate
  3. Install dependencies: pip install -r requirements.txt && pip install -r requirements-dev.txt
  4. Hook auto-formatting tools: pre-commit install
  5. Hack as much as you want!
  6. Run tests: python -m unittest discover -vs ./tests/
  7. Share your work and create a pull request.

To build the documentation:

cd docs
pip install -r requirements.txt
make html

Comments
  • Only extracting part of the intermediate feature with DataParallel

    Hi @antoinebrl,

    I am using torch.nn.DataParallel on a 2-GPU machine with a batch size of N. Data parallel training splits the input batch into 2 pieces and sends one to each GPU.

    When using torchextractor to obtain the intermediate feature, the input data size and the output size are both N as expected, but the feature size becomes N/2. Does this mean we only extract the features of one GPU? I'm not sure because I didn't find an exact match.

    Can you please explain why this happens? Maybe the normal behavior is returning features from all GPUs or from a specified one?

    A minimal example to reproduce:

    import torch
    import torchvision
    import torchextractor as tx
    
    model = torchvision.models.resnet18(pretrained=True)
    model_gpu = torch.nn.DataParallel(torchvision.models.resnet18(pretrained=True))
    model_gpu.cuda()
    
    model = tx.Extractor(model, ["layer1"])
    model_gpu = tx.Extractor(model_gpu, ["module.layer1"])
    dummy_input = torch.rand(8, 3, 224, 224)
    _, features = model(dummy_input)
    _, features_gpu = model_gpu(dummy_input)
    feature_shapes = {name: f.shape for name, f in features.items()}
    print(feature_shapes)
    feature_shapes_gpu = {name: f.shape for name, f in features_gpu.items()}
    print(feature_shapes_gpu)
    
    # {'layer1': torch.Size([8, 64, 56, 56])}
    # {'module.layer1': torch.Size([4, 64, 56, 56])}
    
    opened by wydwww 5
Releases(v0.3.0)