Pretrained models for Jax/Haiku; MobileNet, ResNet, VGG, Xception.

Overview

Pre-trained image classification models for Jax/Haiku

Jax/Haiku Applications are deep learning models that are made available alongside pre-trained weights. These models can be used for prediction, feature extraction, and fine-tuning.

Available Models

  • MobileNetV1
  • ResNet, ResNetV2
  • VGG16, VGG19
  • Xception

Planned Releases

  • MobileNetV2, MobileNetV3
  • InceptionResNetV2, InceptionV3
  • EfficientNetV1, EfficientNetV2

Installation

Haikumodels require Python 3.7 or later.

  1. Needed libraries can be installed using "installation.txt".
  2. If Jax GPU support desired, must be installed seperately according to system needs.

Usage examples for image classification models

Classify ImageNet classes with ResNet50

import haiku as hk
import jax
import jax.numpy as jnp
from PIL import Image

import haikumodels as hm

rng = jax.random.PRNGKey(42)


def _model(images, is_training):
  net = hm.ResNet50()
  return net(images, is_training)


model = hk.transform_with_state(_model)

img_path = "elephant.jpg"
img = Image.open(img_path).resize((224, 224))

x = jnp.asarray(img, dtype=jnp.float32)
x = jnp.expand_dims(x, axis=0)
x = hm.resnet.preprocess_input(x)

params, state = model.init(rng, x, is_training=True)

preds, _ = model.apply(params, state, None, x, is_training=False)
# decode the results into a list of tuples (class, description, probability)
# (one such list for each sample in the batch)
print("Predicted:", hm.decode_predictions(preds, top=3)[0])
# Predicted:
# [('n02504013', 'Indian_elephant', 0.8784022),
# ('n01871265', 'tusker', 0.09620289),
# ('n02504458', 'African_elephant', 0.025362419)]

Extract features with VGG16

import haiku as hk
import jax
import jax.numpy as jnp
from PIL import Image

import haikumodels as hm

rng = jax.random.PRNGKey(42)

model = hk.without_apply_rng(hk.transform(hm.VGG16(include_top=False)))

img_path = "elephant.jpg"
img = Image.open(img_path).resize((224, 224))

x = jnp.asarray(img, dtype=jnp.float32)
x = jnp.expand_dims(x, axis=0)
x = hm.vgg.preprocess_input(x)

params = model.init(rng, x)

features = model.apply(params, x)

Fine-tune Xception on a new set of classes

from typing import Callable, Any, Sequence, Optional

import optax
import haiku as hk
import jax
import jax.numpy as jnp

import haikumodels as hm

rng = jax.random.PRNGKey(42)


class Freezable_TrainState(NamedTuple):
  trainable_params: hk.Params
  non_trainable_params: hk.Params
  state: hk.State
  opt_state: optax.OptState


# create your custom top layers and include the desired pretrained model
class ft_xception(hk.Module):

  def __init__(
      self,
      classes: int,
      classifier_activation: Callable[[jnp.ndarray],
                                      jnp.ndarray] = jax.nn.softmax,
      with_bias: bool = True,
      w_init: Callable[[Sequence[int], Any], jnp.ndarray] = None,
      b_init: Callable[[Sequence[int], Any], jnp.ndarray] = None,
      name: Optional[str] = None,
  ):
    super().__init__(name=name)
    self.classifier_activation = classifier_activation

    self.xception_no_top = hm.Xception(include_top=False)
    self.dense_layer = hk.Linear(
        output_size=1024,
        with_bias=with_bias,
        w_init=w_init,
        b_init=b_init,
        name="trainable_dense_layer",
    )
    self.top_layer = hk.Linear(
        output_size=classes,
        with_bias=with_bias,
        w_init=w_init,
        b_init=b_init,
        name="trainable_top_layer",
    )

  def __call__(self, inputs: jnp.ndarray, is_training: bool):
    out = self.xception_no_top(inputs, is_training)
    out = jnp.mean(out, axis=(1, 2))
    out = self.dense_layer(out)
    out = jax.nn.relu(out)
    out = self.top_layer(out)
    out = self.classifier_activation(out)


# use `transform_with_state` if models has batchnorm in it
# else use `transform` and then `without_apply_rng`
def _model(images, is_training):
  net = ft_xception(classes=200)
  return net(images, is_training)


model = hk.transform_with_state(_model)

# create your desired optimizer using Optax or alternatives
opt = optax.rmsprop(learning_rate=1e-4, momentum=0.90)


# this function will initialize params and state
# use the desired keyword to divide params to trainable and non_trainable
def initial_state(x_y, nonfreeze_key="trainable"):
  x, _ = x_y
  params, state = model.init(rng, x, is_training=True)

  trainable_params, non_trainable_params = hk.data_structures.partition(
      lambda m, n, p: nonfreeze_key in m, params)

  opt_state = opt.init(params)

  return Freezable_TrainState(trainable_params, non_trainable_params, state,
                              opt_state)


train_state = initial_state(next(gen_x_y))


# create your own custom loss function as desired
def loss_function(trainable_params, non_trainable_params, state, x_y):
  x, y = x_y
  params = hk.data_structures.merge(trainable_params, non_trainable_params)
  y_, state = model.apply(params, state, None, x, is_training=True)

  cce = categorical_crossentropy(y, y_)

  return cce, state


# to update params and optimizer, a train_step function must be created
@jax.jit
def train_step(train_state: Freezable_TrainState, x_y):
  trainable_params, non_trainable_params, state, opt_state = train_state
  trainable_params_grads, _ = jax.grad(loss_function,
                                       has_aux=True)(trainable_params,
                                                     non_trainable_params,
                                                     state, x_y)

  updates, new_opt_state = opt.update(trainable_params_grads, opt_state)
  new_trainable_params = optax.apply_updates(trainable_params, updates)

  train_state = Freezable_TrainState(new_trainable_params, non_trainable_params,
                                     state, new_opt_state)
  return train_state


# train the model on the new data for few epochs
train_state = train_step(train_state, next(gen_x_y))

# after training is complete it possible to merge
# trainable and non_trainable params to use for prediction
trainable_params, non_trainable_params, state, _ = train_state
params = hk.data_structures.merge(trainable_params, non_trainable_params)
preds, _ = model.apply(params, state, None, x, is_training=False)
You might also like...
3D ResNet Video Classification accelerated by TensorRT
3D ResNet Video Classification accelerated by TensorRT

Activity Recognition TensorRT Perform video classification using 3D ResNets trained on Kinetics-400 dataset and accelerated with TensorRT P.S Click on

improvement of CLIP features over the traditional resnet features on the visual question answering, image captioning, navigation and visual entailment tasks.

CLIP-ViL In our paper "How Much Can CLIP Benefit Vision-and-Language Tasks?", we show the improvement of CLIP features over the traditional resnet fea

PyTorch implementation of the R2Plus1D convolution based ResNet architecture described in the paper "A Closer Look at Spatiotemporal Convolutions for Action Recognition"

R2Plus1D-PyTorch PyTorch implementation of the R2Plus1D convolution based ResNet architecture described in the paper "A Closer Look at Spatiotemporal

PyTorch implementation of MoCo v3 for self-supervised ResNet and ViT.

MoCo v3 for Self-supervised ResNet and ViT Introduction This is a PyTorch implementation of MoCo v3 for self-supervised ResNet and ViT. The original M

Reproduces ResNet-V3 with pytorch
Reproduces ResNet-V3 with pytorch

ResNeXt.pytorch Reproduces ResNet-V3 (Aggregated Residual Transformations for Deep Neural Networks) with pytorch. Tried on pytorch 1.6 Trains on Cifar

DeepLab resnet v2 model in pytorch

pytorch-deeplab-resnet DeepLab resnet v2 model implementation in pytorch. The architecture of deepLab-ResNet has been replicated exactly as it is from

Reproduce ResNet-v2(Identity Mappings in Deep Residual Networks) with MXNet
Reproduce ResNet-v2(Identity Mappings in Deep Residual Networks) with MXNet

Reproduce ResNet-v2 using MXNet Requirements Install MXNet on a machine with CUDA GPU, and it's better also installed with cuDNN v5 Please fix the ran

NFT-Price-Prediction-CNN - Using visual feature extraction, prices of NFTs are predicted via CNN (Alexnet and Resnet) architectures.

NFT-Price-Prediction-CNN - Using visual feature extraction, prices of NFTs are predicted via CNN (Alexnet and Resnet) architectures.

In this project we use both Resnet and Self-attention layer for cat, dog and flower classification.
In this project we use both Resnet and Self-attention layer for cat, dog and flower classification.

cdf_att_classification classes = {0: 'cat', 1: 'dog', 2: 'flower'} In this project we use both Resnet and Self-attention layer for cdf-Classification.

Comments
  • Expected top-1 test accuracy

    Expected top-1 test accuracy

    Hi

    This is a fantastic project! The released checkpoints are super helpful!

    I am wondering what's the top-1 test accuracy that one should get using the released ResNet-50 checkpoints. I am able to reach 0.749 using the my own ImageNet dataloader implemented via Tensorflow Datasets. Is the number close to your results?

    BTW, it would also be very helpful if you could release your training and dataloading code for these models!

    Thanks,

    opened by xidulu 2
  • Fitting issue

    Fitting issue

    I was trying to use a few of your pre-trained models, in particular the ResNet50 and VGG16 for features extraction, but unfortunately I didn't manage to fit on the Nvidia Titan X with 12GB of VRAM my question is which GPU did you use for training, how much VRAM I need for use them?

    For the VGG16 the system was asking me for 4 more GB and for the ResNet50 about 20 more

    Thanks.

    opened by mattiadutto 1
Owner
Alper Baris CELIK
Alper Baris CELIK
Misc YOLOL scripts for use in the Starbase space sandbox videogame

starbase-misc Misc YOLOL scripts for use in the Starbase space sandbox videogame. Each directory contains standalone YOLOL scripts. They don't really

4 Oct 17, 2021
E2EC: An End-to-End Contour-based Method for High-Quality High-Speed Instance Segmentation

E2EC: An End-to-End Contour-based Method for High-Quality High-Speed Instance Segmentation E2EC: An End-to-End Contour-based Method for High-Quality H

zhangtao 146 Dec 29, 2022
This is the code repository for the paper "Identification of the Generalized Condorcet Winner in Multi-dueling Bandits" (NeurIPS 2021).

Code Repository for the Paper "Identification of the Generalized Condorcet Winner in Multi-dueling Bandits" (To appear in: Proceedings of NeurIPS20

1 Oct 03, 2022
GT4SD, an open-source library to accelerate hypothesis generation in the scientific discovery process.

The GT4SD (Generative Toolkit for Scientific Discovery) is an open-source platform to accelerate hypothesis generation in the scientific discovery process. It provides a library for making state-of-t

Generative Toolkit 4 Scientific Discovery 142 Dec 24, 2022
This repo contains the code required to train the multivariate time-series Transformer.

Multi-Variate Time-Series Transformer This repo contains the code required to train the multivariate time-series Transformer. Download the data The No

Gregory Duthé 4 Nov 24, 2022
Interacting Two-Hand 3D Pose and Shape Reconstruction from Single Color Image (ICCV 2021)

Interacting Two-Hand 3D Pose and Shape Reconstruction from Single Color Image Interacting Two-Hand 3D Pose and Shape Reconstruction from Single Color

75 Dec 02, 2022
Course about deep learning for computer vision and graphics co-developed by YSDA and Skoltech.

Deep Vision and Graphics This repo supplements course "Deep Vision and Graphics" taught at YSDA @fall'21. The course is the successor of "Deep Learnin

Yandex School of Data Analysis 160 Jan 02, 2023
Partial implementation of ODE-GAN technique from the paper Training Generative Adversarial Networks by Solving Ordinary Differential Equations

ODE GAN (Prototype) in PyTorch Partial implementation of ODE-GAN technique from the paper Training Generative Adversarial Networks by Solving Ordinary

Somshubra Majumdar 15 Feb 10, 2022
Official repository for "PAIR: Planning and Iterative Refinement in Pre-trained Transformers for Long Text Generation"

pair-emnlp2020 Official repository for the paper: Xinyu Hua and Lu Wang: PAIR: Planning and Iterative Refinement in Pre-trained Transformers for Long

Xinyu Hua 31 Oct 13, 2022
My coursework for Machine Learning (2021 Spring) at National Taiwan University (NTU)

Machine Learning 2021 Machine Learning (NTU EE 5184, Spring 2021) Instructor: Hung-yi Lee Course Website : (https://speech.ee.ntu.edu.tw/~hylee/ml/202

100 Dec 26, 2022
Official PyTorch Implementation of paper EAN: Event Adaptive Network for Efficient Action Recognition

Official PyTorch Implementation of paper EAN: Event Adaptive Network for Efficient Action Recognition

TianYuan 27 Nov 07, 2022
Implementation of SegNet: A Deep Convolutional Encoder-Decoder Architecture for Semantic Pixel-Wise Labelling

Caffe SegNet This is a modified version of Caffe which supports the SegNet architecture As described in SegNet: A Deep Convolutional Encoder-Decoder A

Alex Kendall 1.1k Jan 02, 2023
A developer interface for creating Chat AIs for the Chai app.

ChaiPy A developer interface for creating Chat AIs for the Chai app. Usage Local development A quick start guide is available here, with a minimal exa

Chai 28 Dec 28, 2022
An efficient and easy-to-use deep learning model compression framework

TinyNeuralNetwork 简体中文 TinyNeuralNetwork is an efficient and easy-to-use deep learning model compression framework, which contains features like neura

Alibaba 441 Dec 25, 2022
Official codebase for "B-Pref: Benchmarking Preference-BasedReinforcement Learning" contains scripts to reproduce experiments.

B-Pref Official codebase for B-Pref: Benchmarking Preference-BasedReinforcement Learning contains scripts to reproduce experiments. Install conda env

48 Dec 20, 2022
Graph Robustness Benchmark: A scalable, unified, modular, and reproducible benchmark for evaluating the adversarial robustness of Graph Machine Learning.

Homepage | Paper | Datasets | Leaderboard | Documentation Graph Robustness Benchmark (GRB) provides scalable, unified, modular, and reproducible evalu

THUDM 66 Dec 22, 2022
PyTorch implementation of "VRT: A Video Restoration Transformer"

VRT: A Video Restoration Transformer Jingyun Liang, Jiezhang Cao, Yuchen Fan, Kai Zhang, Rakesh Ranjan, Yawei Li, Radu Timofte, Luc Van Gool Computer

Jingyun Liang 837 Jan 09, 2023
This is the code for the paper "Contrastive Clustering" (AAAI 2021)

Contrastive Clustering (CC) This is the code for the paper "Contrastive Clustering" (AAAI 2021) Dependency python=3.7 pytorch=1.6.0 torchvision=0.8

Yunfan Li 210 Dec 30, 2022
[CVPR-2021] UnrealPerson: An adaptive pipeline for costless person re-identification

UnrealPerson: An Adaptive Pipeline for Costless Person Re-identification In our paper (arxiv), we propose a novel pipeline, UnrealPerson, that decreas

ZhangTianyu 70 Oct 10, 2022
[ICML 2020] "When Does Self-Supervision Help Graph Convolutional Networks?" by Yuning You, Tianlong Chen, Zhangyang Wang, Yang Shen

When Does Self-Supervision Help Graph Convolutional Networks? PyTorch implementation for When Does Self-Supervision Help Graph Convolutional Networks?

Shen Lab at Texas A&M University 106 Nov 11, 2022