Runtime type annotations for the shape, dtype etc. of PyTorch Tensors.

Overview

torchtyping

Type annotations for a tensor's shape, dtype, names, ...

Turn this:

import torch

def batch_outer_product(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # x has shape (batch, x_channels)
    # y has shape (batch, y_channels)
    # return has shape (batch, x_channels, y_channels)

    return x.unsqueeze(-1) * y.unsqueeze(-2)

into this:

from torchtyping import TensorType

def batch_outer_product(x:   TensorType["batch", "x_channels"],
                        y:   TensorType["batch", "y_channels"]
                        ) -> TensorType["batch", "x_channels", "y_channels"]:

    return x.unsqueeze(-1) * y.unsqueeze(-2)

with programmatic checking that the shape (dtype, ...) specification is met.

Bye-bye bugs! Say hello to enforced, clear documentation of your code.

If (like me) you find yourself littering your code with comments like # x has shape (batch, hidden_state) or statements like assert x.shape == y.shape, just to keep track of what shape everything is, then this is for you.
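Why is the return annotation in batch_outer_product (batch, x_channels, y_channels)? x.unsqueeze(-1) has shape (batch, x_channels, 1), y.unsqueeze(-2) has shape (batch, 1, y_channels), and elementwise multiplication broadcasts the size-1 dimensions. Here's a plain-Python sketch of that shape arithmetic, assuming PyTorch's standard (NumPy-style) broadcasting rules. The helpers unsqueeze_shape and broadcast_shape are hypothetical names for illustration, not part of torchtyping or PyTorch.

```python
from itertools import zip_longest
from typing import Sequence, Tuple


def unsqueeze_shape(shape: Sequence[int], dim: int) -> Tuple[int, ...]:
    """Shape after inserting a size-1 dimension at `dim` (like Tensor.unsqueeze)."""
    if dim < 0:
        dim += len(shape) + 1  # negative dims count from the end of the new shape
    return tuple(shape[:dim]) + (1,) + tuple(shape[dim:])


def broadcast_shape(a: Sequence[int], b: Sequence[int]) -> Tuple[int, ...]:
    """Result shape of an elementwise op between tensors of shapes a and b."""
    result = []
    # Align from the right; a missing leading dimension behaves like size 1.
    for da, db in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if da != db and 1 not in (da, db):
            raise ValueError(f"shapes {tuple(a)} and {tuple(b)} do not broadcast")
        result.append(db if da == 1 else da)  # a size-1 dim stretches to match
    return tuple(reversed(result))


batch, x_channels, y_channels = 4, 3, 5  # illustrative sizes
x_shape = unsqueeze_shape((batch, x_channels), -1)  # (4, 3, 1)
y_shape = unsqueeze_shape((batch, y_channels), -2)  # (4, 1, 5)
print(broadcast_shape(x_shape, y_shape))  # (4, 3, 5)
```

This is exactly the bookkeeping the TensorType annotations let you state once, instead of re-deriving it in comments.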


Installation

pip install torchtyping

Requires Python 3.7+ and PyTorch 1.7.0+.

Usage

torchtyping allows for type annotating:

  • shape: size, number of dimensions;
  • dtype (float, integer, etc.);
  • layout (dense, sparse);
  • names of dimensions as per named tensors;
  • arbitrary number of batch dimensions with ...;
  • ...plus anything else you like, as torchtyping is highly extensible.

If typeguard is (optionally) installed then at runtime the types can be checked to ensure that the tensors really are of the advertised shape, dtype, etc.

# EXAMPLE

from torch import rand
from torchtyping import TensorType, patch_typeguard
from typeguard import typechecked

patch_typeguard()  # use before @typechecked

@typechecked
def func(x: TensorType["batch"],
         y: TensorType["batch"]) -> TensorType["batch"]:
    return x + y

func(rand(3), rand(3))  # works
func(rand(3), rand(1))
# TypeError: Dimension 'batch' of inconsistent size. Got both 1 and 3.
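Conceptually, the check in this example binds each dimension name to the first size it sees and rejects any later tensor that disagrees. The following plain-Python sketch models just that rule on bare shape tuples, with no PyTorch or typeguard involved. check_named_dims is a hypothetical name and its error text merely imitates the message above, so treat it as an illustration of the idea rather than torchtyping's actual implementation.

```python
from typing import Dict, Sequence, Tuple


def check_named_dims(*pairs: Tuple[Sequence[str], Sequence[int]]) -> Dict[str, int]:
    """Check (dimension names, shape) pairs against one shared name -> size table."""
    bound: Dict[str, int] = {}
    for names, shape in pairs:
        if len(names) != len(shape):
            raise TypeError(f"Expected {len(names)} dimensions, got {len(shape)}.")
        for name, size in zip(names, shape):
            # The first tensor to use a name fixes its size; later ones must agree.
            previous = bound.setdefault(name, size)
            if previous != size:
                raise TypeError(
                    f"Dimension '{name}' of inconsistent size. "
                    f"Got both {size} and {previous}."
                )
    return bound


# Mirrors func(rand(3), rand(3)) and func(rand(3), rand(1)) above, using bare shapes.
print(check_named_dims((["batch"], (3,)), (["batch"], (3,))))  # {'batch': 3}
try:
    check_named_dims((["batch"], (3,)), (["batch"], (1,)))
except TypeError as error:
    print(error)  # Dimension 'batch' of inconsistent size. Got both 1 and 3.
```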

typeguard also has an import hook that can be used to automatically test an entire module, without needing to manually add @typeguard.typechecked decorators.

If you're not using typeguard then torchtyping.patch_typeguard() can be omitted altogether, and torchtyping just used for documentation purposes. If you're not already using typeguard for your regular Python programming, then strongly consider using it. It's a great way to squash bugs. Both typeguard and torchtyping also integrate with pytest, so if you're concerned about any performance penalty then they can be enabled during tests only.

API

torchtyping.TensorType[shape, dtype, layout, details]

The core of the library.

Each of shape, dtype, layout, details is optional.

  • The shape argument can be any of:
    • An int: the dimension must be of exactly this size. If it is -1 then any size is allowed.
    • A str: the size of the dimension passed at runtime will be bound to this name, and all tensors will be checked to ensure the sizes are consistent.
    • A ...: An arbitrary number of dimensions of any sizes.
    • A str: int pair (technically it's a slice), combining both str and int behaviour. (Just a str on its own is equivalent to str: -1.)
    • A str: ... pair, in which case the multiple dimensions corresponding to ... will be bound to the name specified by str, and again checked for consistency between arguments.
    • None, which when used in conjunction with is_named below, indicates a dimension that must not have a name in the sense of named tensors.
    • A None: int pair, combining both None and int behaviour. (Just a None on its own is equivalent to None: -1.)
    • A typing.Any: Any size is allowed for this dimension (equivalent to -1).
    • Any tuple of the above. For example, TensorType["batch": ..., "length": 10, "channels", -1]. If you just want to specify the number of dimensions, use for example TensorType[-1, -1, -1] for a three-dimensional tensor.
  • The dtype argument can be any of:
    • torch.float32, torch.float64 etc.
    • int, bool, float, which are converted to their corresponding PyTorch types. float is specifically interpreted as torch.get_default_dtype(), which is usually float32.
  • The layout argument can be either torch.strided or torch.sparse_coo, for dense and sparse tensors respectively.
  • The details argument offers a way to pass an arbitrary number of additional flags that customise and extend torchtyping. Two flags are built-in by default. torchtyping.is_named causes the names of tensor dimensions to be checked, and torchtyping.is_float can be used to check that arbitrary floating point types are passed in. (Rather than just a specific one as with e.g. TensorType[torch.float32].) For discussion on how to customise torchtyping with your own details, see the further documentation.
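The shape rules above can be modelled in a few lines of plain Python. This is a hypothetical sketch (normalise and match_shape are invented names, not torchtyping's internals) that checks a concrete shape tuple against a spec tuple and records name bindings, so later tensors can be checked for consistency. It covers int, -1, str, ..., str: int, str: ... and typing.Any. It skips None and everything related to names, dtype, layout or details, and it assumes at most one ... per spec. Inside TensorType[...], Python turns "batch": ... into a slice object, so the sketch spells those entries as explicit slice() calls.

```python
import typing
from typing import Dict, Optional, Sequence, Tuple


def normalise(entry: object) -> Tuple[Optional[str], object]:
    """Turn one shape entry into a (name, size) pair; size -1 means 'any size'."""
    if entry is typing.Any:
        return None, -1
    if entry is Ellipsis:
        return None, Ellipsis
    if isinstance(entry, int):
        return None, entry
    if isinstance(entry, str):
        return entry, -1  # a bare str behaves like str: -1
    if isinstance(entry, slice) and isinstance(entry.start, str):
        return entry.start, entry.stop  # "name": int or "name": ...
    raise TypeError(f"unsupported shape entry: {entry!r}")


def match_shape(spec: Sequence[object], shape: Sequence[int],
                bindings: Dict[str, object]) -> bool:
    """Return True if `shape` fits `spec`; on success, record new name bindings."""
    entries = [normalise(e) for e in spec]
    ellipses = [i for i, (_, size) in enumerate(entries) if size is Ellipsis]
    if len(ellipses) > 1:
        raise ValueError("this sketch allows at most one ... per spec")
    if ellipses:
        i = ellipses[0]
        head, tail = entries[:i], entries[i + 1:]
        if len(shape) < len(head) + len(tail):
            return False
        split = len(shape) - len(tail)
        dims = list(zip(head, shape[:len(head)])) + list(zip(tail, shape[split:]))
        groups = [(entries[i][0], tuple(shape[len(head):split]))]
    else:
        if len(shape) != len(entries):
            return False
        dims, groups = list(zip(entries, shape)), []
    trial = dict(bindings)  # only commit bindings if the whole shape matches
    for (name, size), dim in dims:
        if size != -1 and dim != size:
            return False
        if name is not None and trial.setdefault(name, dim) != dim:
            return False
    for name, group in groups:
        if name is not None and trial.setdefault(name, group) != group:
            return False
    bindings.update(trial)
    return True


# TensorType["batch": ..., "length": 10, "channels", -1], written with explicit slices.
spec = (slice("batch", ...), slice("length", 10), "channels", -1)
bindings: Dict[str, object] = {}
print(match_shape(spec, (2, 7, 10, 3, 4), bindings))  # True
print(bindings)  # {'length': 10, 'channels': 3, 'batch': (2, 7)}
print(match_shape(spec, (2, 7, 10, 5, 4), bindings))  # False: 'channels' was 3
```

Passing the same bindings dict to every call is what makes a name like "channels" act as a cross-argument constraint rather than a per-tensor one.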

Check multiple things at once by just putting them all together inside a single []. For example TensorType["batch": ..., "length", "channels", float, is_named].
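Everything inside a single [] reaches Python as one tuple, and each "name": size entry arrives as a slice object. A minimal sketch of how such a tuple could be sorted into shape, dtype and details follows. SpecSketch, ParsedSpec and the IS_NAMED marker are hypothetical stand-ins (not torchtyping's classes), and since there's no PyTorch here, only the Python types int, bool and float are treated as dtypes.

```python
from typing import Any, NamedTuple, Optional, Tuple


class ParsedSpec(NamedTuple):
    shape: Tuple[Any, ...]
    dtype: Optional[type]
    flags: Tuple[Any, ...]


IS_NAMED = object()  # hypothetical stand-in for a detail flag like torchtyping.is_named
PY_DTYPES = (int, bool, float)  # Python types that TensorType accepts as dtypes


class SpecSketch:
    """Hypothetical TensorType look-alike that only sorts its [] arguments."""

    def __class_getitem__(cls, item: Any) -> ParsedSpec:
        # X[a, b, c] passes the tuple (a, b, c); X[a] passes a on its own.
        items = item if isinstance(item, tuple) else (item,)
        shape, dtype, flags = [], None, []
        for entry in items:
            if any(entry is t for t in PY_DTYPES):
                dtype = entry
            elif entry is IS_NAMED:
                flags.append(entry)
            else:
                shape.append(entry)  # str, int, ..., "name": size slices, etc.
        return ParsedSpec(tuple(shape), dtype, tuple(flags))


spec = SpecSketch["batch": ..., "length", "channels", float, IS_NAMED]
print(spec.shape)  # (slice('batch', Ellipsis, None), 'length', 'channels')
print(spec.dtype)  # <class 'float'>
```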

torchtyping.patch_typeguard()

torchtyping integrates with typeguard to perform runtime type checking. torchtyping.patch_typeguard() should be called at the global level, and will patch typeguard to check TensorTypes.

This function is safe to run multiple times. (It does nothing after the first run.)
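A common way to make a patch function safe to call repeatedly is a module-level guard flag. The sketch below shows that pattern on a hypothetical checker object; patch_checker, check_type and the "positive-int" rule are all invented for illustration, and this is not a claim about how torchtyping.patch_typeguard() is implemented.

```python
import types

# Hypothetical stand-in for the module being patched.
checker = types.SimpleNamespace(
    check_type=lambda value, expected: isinstance(value, expected)
)

_patched = False


def patch_checker() -> None:
    """Wrap checker.check_type with an extra rule; repeat calls do nothing."""
    global _patched
    if _patched:
        return
    original = checker.check_type

    def check_type(value, expected):
        # New rule added by the patch; everything else goes to the original.
        if expected == "positive-int":
            return isinstance(value, int) and value > 0
        return original(value, expected)

    checker.check_type = check_type
    _patched = True


patch_checker()
wrapped = checker.check_type
patch_checker()  # no-op: the wrapper is not stacked a second time
print(checker.check_type is wrapped)  # True
print(checker.check_type(3, "positive-int"), checker.check_type("a", str))  # True True
```

Without the guard, each extra call would wrap the already-wrapped function again, which is harmless here but wasteful and confusing to debug.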

  • If using @typeguard.typechecked, then torchtyping.patch_typeguard() should be called any time before using @typeguard.typechecked. For example you could call it at the start of each file using torchtyping.
  • If using typeguard.importhook.install_import_hook, then torchtyping.patch_typeguard() should be called any time before defining the functions you want checked. For example you could call torchtyping.patch_typeguard() just once, at the same time as the typeguard import hook. (The order of the hook and the patch doesn't matter.)
  • If you're not using typeguard then torchtyping.patch_typeguard() can be omitted altogether, and torchtyping just used for documentation purposes.

pytest --torchtyping-patch-typeguard

torchtyping offers a pytest plugin to automatically run torchtyping.patch_typeguard() before your tests. pytest will automatically discover the plugin; you just need to pass the --torchtyping-patch-typeguard flag to enable it. Packages can then be passed to typeguard as normal, either by using @typeguard.typechecked, typeguard's import hook, or the pytest flag --typeguard-packages="your_package_here".

Further documentation

See the further documentation for:

  • FAQ;
    • Including flake8 and mypy compatibility;
  • How to write custom extensions to torchtyping;
  • Resources and links to other libraries and materials on this topic;
  • More examples.

Owner: Patrick Kidger
Maths+ML PhD student at Oxford. Neural ODEs+SDEs+CDEs, time series, rough analysis. (Also ice skating, martial arts and scuba diving!)