Compare neural networks by their feature similarity

Overview

PyTorch Model Compare

A tiny package to compare two neural networks in PyTorch. There are many ways to compare two neural networks, but one robust and scalable way is to use the Centered Kernel Alignment (CKA) metric, which compares the features (internal representations) learned by the networks.

Centered Kernel Alignment

Centered Kernel Alignment (CKA) is a representation similarity metric that is widely used for understanding the representations learned by neural networks. Specifically, CKA takes as input two feature maps / representations X and Y, computed on the same n examples, and computes their normalized similarity (in terms of the Hilbert-Schmidt Independence Criterion (HSIC)) as

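$$\mathrm{CKA}(K, L) = \frac{\mathrm{HSIC}(K, L)}{\sqrt{\mathrm{HSIC}(K, K)\,\mathrm{HSIC}(L, L)}}, \qquad \mathrm{HSIC}(K, L) = \frac{1}{(n-1)^2}\operatorname{tr}(KHLH),$$

where n is the number of examples and $H = I_n - \frac{1}{n}\mathbf{1}\mathbf{1}^\top$ is the centering matrix.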

K and L are the similarity (kernel) matrices of X and Y, respectively. However, this formula does not scale well to deep architectures and large datasets: computing K and L requires the features of all n examples at once, and both matrices grow as n × n. Therefore, a minibatch version can be constructed that uses an unbiased estimator of the HSIC as

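$$\mathrm{CKA}_{\text{minibatch}} = \frac{\frac{1}{k}\sum_{i=1}^{k}\mathrm{HSIC}_1(K_i, L_i)}{\sqrt{\frac{1}{k}\sum_{i=1}^{k}\mathrm{HSIC}_1(K_i, K_i)}\,\sqrt{\frac{1}{k}\sum_{i=1}^{k}\mathrm{HSIC}_1(L_i, L_i)}},$$

where $K_i = X_i X_i^\top$ and $L_i = Y_i Y_i^\top$ are computed on the i-th of k minibatches of n examples each, and the unbiased estimator $\mathrm{HSIC}_1$ is

$$\mathrm{HSIC}_1(K, L) = \frac{1}{n(n-3)}\left(\operatorname{tr}(\tilde{K}\tilde{L}) + \frac{\mathbf{1}^\top\tilde{K}\mathbf{1}\;\mathbf{1}^\top\tilde{L}\mathbf{1}}{(n-1)(n-2)} - \frac{2}{n-2}\,\mathbf{1}^\top\tilde{K}\tilde{L}\mathbf{1}\right),$$

with $\tilde{K}$ and $\tilde{L}$ denoting K and L with their diagonal entries set to zero.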

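The minibatch form of CKA above is from the ICLR 2021 paper by T. Nguyen, M. Raghu and S. Kornblith, "Do Wide and Deep Networks Learn the Same Things? Uncovering How Neural Network Representations Vary with Width and Depth".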
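To make the formulas concrete, here is a minimal NumPy sketch of linear CKA, assuming a linear kernel (K = XX^T) and features flattened per example. It covers the full-batch version, with HSIC(K, L) = tr(KHLH)/(n-1)^2, and the minibatch version built on the unbiased HSIC_1 estimator. It is not the torch_cka implementation; all function names are hypothetical.

```python
# Sketch only: hypothetical helpers, not the torch_cka implementation.
import numpy as np


def linear_kernel(X):
    """Linear kernel K = X X^T for an (n, ...) feature array, flattened per example."""
    X = np.asarray(X, dtype=float)
    X = X.reshape(X.shape[0], -1)
    return X @ X.T


def hsic0(K, L):
    """Biased HSIC: tr(K H L H) / (n - 1)^2 with centering matrix H = I - (1/n) 1 1^T."""
    n = K.shape[0]
    H = np.eye(n) - np.full((n, n), 1.0 / n)
    return np.trace(K @ H @ L @ H) / (n - 1) ** 2


def cka(X, Y):
    """Full-batch linear CKA between features X and Y computed on the same n examples."""
    K, L = linear_kernel(X), linear_kernel(Y)
    return hsic0(K, L) / np.sqrt(hsic0(K, K) * hsic0(L, L))


def hsic1(K, L):
    """Unbiased HSIC estimator used in the minibatch formula (requires n > 3)."""
    n = K.shape[0]
    K_t, L_t = K.copy(), L.copy()
    np.fill_diagonal(K_t, 0.0)  # K~: K with its diagonal set to zero
    np.fill_diagonal(L_t, 0.0)  # L~: L with its diagonal set to zero
    ones = np.ones(n)
    trace_term = np.trace(K_t @ L_t)
    sum_term = (ones @ K_t @ ones) * (ones @ L_t @ ones) / ((n - 1) * (n - 2))
    cross_term = 2.0 / (n - 2) * (ones @ K_t @ L_t @ ones)
    return (trace_term + sum_term - cross_term) / (n * (n - 3))


def minibatch_cka(batches_X, batches_Y):
    """Minibatch linear CKA over k paired minibatches (the 1/k factors cancel)."""
    xy = xx = yy = 0.0
    for Xb, Yb in zip(batches_X, batches_Y):
        K, L = linear_kernel(Xb), linear_kernel(Yb)
        xy += hsic1(K, L)
        xx += hsic1(K, K)
        yy += hsic1(L, L)
    return xy / np.sqrt(xx * yy)


rng = np.random.default_rng(0)
X = rng.normal(size=(64, 10))
Q, _ = np.linalg.qr(rng.normal(size=(10, 10)))  # random orthogonal matrix
print(cka(X, X))            # ~1.0
print(cka(X, 3.0 * X @ Q))  # ~1.0: linear CKA ignores isotropic scaling and rotations
```

Because the 1/k factors cancel in the ratio, the sketch simply sums HSIC_1 over the minibatches. The minibatches of X and Y must be paired, i.e. computed on the same inputs, and the accumulation means only one minibatch kernel needs to be in memory at a time.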

Getting Started

Installation

pip install torch_cka

Usage

from torch.utils.data import DataLoader
from torchvision.models import resnet18, resnet34
from torch_cka import CKA

model1 = resnet18(pretrained=True)  # Or any neural network of your choice
model2 = resnet34(pretrained=True)

# your_dataset, batch_size and the layer name lists below are placeholders
# for your own data, settings and layers.
dataloader = DataLoader(your_dataset,
                        batch_size=batch_size,  # according to your device memory
                        shuffle=False)  # Don't forget to seed your dataloader

cka = CKA(model1, model2,
          model1_name="ResNet18",   # good idea to provide names to avoid confusion
          model2_name="ResNet34",
          model1_layers=layer_names_resnet18,  # list of layers to extract features from
          model2_layers=layer_names_resnet34,  # optional; all layers are used by default
          device='cuda')

cka.compare(dataloader)  # a second dataloader is optional

results = cka.export()  # returns a dict that contains the model names, layer names
                        # and the CKA matrix
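The usage example assumes you already have lists of layer names such as layer_names_resnet18. A minimal sketch for building them, assuming (not confirmed here) that torch_cka matches model1_layers/model2_layers against the names returned by nn.Module.named_modules(). The helper layer_names and the choice of Conv2d/Linear layers are hypothetical, for illustration only.

```python
import torch.nn as nn


def layer_names(model, types=(nn.Conv2d, nn.Linear)):
    """Dotted names (as reported by named_modules) of submodules of the given types.

    Hypothetical helper; assumes torch_cka accepts these names as model*_layers.
    """
    return [name for name, module in model.named_modules() if isinstance(module, types)]


# A small stand-in model so the example runs without downloading weights.
toy = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3), nn.ReLU(),
    nn.Sequential(nn.Conv2d(8, 16, kernel_size=3), nn.ReLU()),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 10),
)
print(layer_names(toy))  # ['0', '2.0', '5']

# With a torchvision model you would call, e.g., layer_names(model1)
# to build layer_names_resnet18 for the CKA constructor.
```

Filtering by module type keeps the list short, which also helps with the memory tips further down; pass a different `types` tuple to select other layers.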

Examples

torch_cka can be used with any PyTorch model (a subclass of nn.Module), including pretrained models from popular sources such as Torch Hub, timm and Hugging Face. Some examples of where this package can come in handy are illustrated below.

Comparing the effect of Depth

A simple experiment is to analyse the features learned by two architectures of the same family (ResNets) but of different depths. Taking ResNet18 and ResNet34, both pretrained on the ImageNet dataset, we can analyse how they produce their features on, say, CIFAR-10 for simplicity. This comparison is shown as a heatmap below.

[Figure: CKA heatmap comparing the layer features of ResNet18 and ResNet34 on CIFAR-10]

We see a high degree of similarity between the two models in the lower layers, as they both learn similar representations from the data. At the higher layers, however, the similarity decreases, as the deeper model (ResNet34) learns higher-order features that are elusive to the shallower model (ResNet18). Even so, the two models do show some similarity in their last fc layer, which acts as the classifier.

Comparing Two Similar Architectures

Another use of CKA is in ablation studies. Rather than focusing only on the resulting performance, as many ablation studies do, we can use CKA to study the internal representations. Take ResNet50 and WideResNet50 (k=2) as an example: WideResNet50 has the same architecture as ResNet50, except that its residual bottleneck layers are wider (by a factor of 2 in this case).

[Figure: CKA heatmap comparing ResNet50 and WideResNet50 (k=2)]

We clearly see that the learned features diverge after the first few layers. The width has a more pronounced effect in the deeper layers; in the initial layers, both networks seem to learn similar features.

As a bonus, here is a comparison between ViT and Swin Transformer, a recent state-of-the-art model, pretrained on ImageNet-22k.

[Figure: CKA heatmap comparing ViT and Swin Transformer]

Comparing quite different architectures

CNNs have been analysed extensively over the past decade, since AlexNet. Largely through visualizations, we have a reasonable idea of what sort of features they learn across their layers, and we have put them to good use. One interesting approach is to compare these well-understood features with those of newer models that do not lend themselves to easy visualization (such as recent vision transformer architectures). This has indeed been a hot research topic (see Raghu et al., 2021).

[Figure: CKA heatmap comparing a CNN with a vision transformer]

Comparing Datasets

Yet another application is to compare two datasets, preferably two versions of the same data. This is especially useful in production, where data drift is a known issue. If you have an updated version of a dataset, you can get a sense of how your model may perform on it by comparing the representations your model produces for the two versions. This can be more telling about actual performance than comparing the raw datasets directly.

This can also be useful when studying the performance of a model on downstream tasks and during fine-tuning. For instance, if the CKA score of some features is high across different datasets, those layers can be frozen during fine-tuning. As an example, the following figure compares the features of a pretrained ResNet50 on the ImageNet test data and on the VOC dataset. The pretrained features clearly have little correlation with the VOC dataset, so we have to resort to fine-tuning to obtain satisfactory results.

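[Figure: CKA heatmap comparing pretrained ResNet50 features on the ImageNet test data and on the VOC dataset]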
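The freezing idea mentioned above can be sketched as follows. This is an illustration under stated assumptions: the per-layer CKA scores are given as a dict keyed by module name (for example, taken from the diagonal of a CKA matrix obtained by comparing the same model on both datasets), and the 0.9 threshold is an arbitrary, hypothetical choice. The helper freeze_similar_layers is hypothetical and not part of torch_cka.

```python
import torch
import torch.nn as nn


def freeze_similar_layers(model, cka_scores, threshold=0.9):
    """Freeze every named submodule whose CKA score is at least `threshold`.

    cka_scores: hypothetical {module_name: CKA score across the two datasets}.
    threshold: hypothetical cut-off; tune it for your own use case.
    Returns the names of the layers that were frozen.
    """
    modules = dict(model.named_modules())
    frozen = []
    for name, score in cka_scores.items():
        if score >= threshold and name in modules:
            modules[name].requires_grad_(False)  # stop gradient updates for this layer
            frozen.append(name)
    return frozen


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
scores = {"0": 0.95, "2": 0.40}  # made-up scores, for illustration only
frozen = freeze_similar_layers(model, scores)

# Hand only the still-trainable parameters to the optimizer for fine-tuning.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=0.01)
print(frozen)  # ['0']
```

Layers whose representations stay similar across datasets keep their pretrained weights, while the remaining layers adapt to the new data.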

Tips

  • If your model is large (many layers or large feature maps), extract features only from selected layers to avoid out-of-memory errors.
  • If you still want to compare all layers, you can run the comparison several times with a few layers at each iteration and export the results of each run using cka.export(). The exported matrices can then be concatenated to produce the full CKA matrix.
  • Give the models proper names to avoid confusion when interpreting the results. By default, the code extracts the model names automatically, but it is good practice to label the models according to your use case.
  • When providing your dataloader(s) to the compare() function, make sure they are seeded properly for reproducibility.
  • When comparing datasets, set drop_last=True when building the dataloaders. This avoids shape mismatch issues, especially with differently sized datasets.
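The chunking tip can be sketched in NumPy. Assumptions: each run compares one chunk of model1 layers with one chunk of model2 layers, and its CKA matrix (converted to a NumPy array first, e.g. with .cpu().numpy() if it is a tensor) has model1 layers as rows and model2 layers as columns. The helpers chunk and assemble are hypothetical.

```python
import numpy as np


def chunk(names, size):
    """Split a list of layer names into consecutive chunks of at most `size` names."""
    return [names[i:i + size] for i in range(0, len(names), size)]


def assemble(blocks):
    """Stitch a grid of per-run CKA matrices into one full matrix.

    blocks[i][j] is the matrix from the run on model1 chunk i and model2 chunk j
    (assumed layout: rows = model1 layers, columns = model2 layers).
    """
    return np.block([[np.asarray(b, dtype=float) for b in row] for row in blocks])


# Self-check with a stand-in "full" matrix: split it the way the runs would,
# then reassemble it.
full = np.arange(20, dtype=float).reshape(5, 4)
rows, cols = chunk(list(range(5)), 2), chunk(list(range(4)), 3)
blocks = [[full[np.ix_(r, c)] for c in cols] for r in rows]
rebuilt = assemble(blocks)
print(np.array_equal(rebuilt, full))  # True
```

With n1 chunks for model1 and n2 chunks for model2 this takes n1 * n2 runs; keep the chunks in their original layer order so the rows and columns of the reassembled matrix line up with the full layer lists.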

Citation

If you use this repo in your project or research, please cite it as:

@software{subramanian2021torch_cka,
    author={Anand Subramanian},
    title={torch_cka},
    url={https://github.com/AntixK/PyTorch-Model-Compare},
    year={2021}
}
Owner: Anand Krishnamoorthy, Research Engineer