Pretrained models for Jax/Haiku: MobileNet, ResNet, VGG, Xception.

Overview

Pre-trained image classification models for Jax/Haiku

Jax/Haiku Applications are deep learning models that are made available alongside pre-trained weights. These models can be used for prediction, feature extraction, and fine-tuning.

Available Models

  • MobileNetV1
  • ResNet, ResNetV2
  • VGG16, VGG19
  • Xception

Planned Releases

  • MobileNetV2, MobileNetV3
  • InceptionResNetV2, InceptionV3
  • EfficientNetV1, EfficientNetV2

Installation

Haikumodels requires Python 3.7 or later.

  1. The required libraries can be installed using "installation.txt".
  2. If Jax GPU support is desired, it must be installed separately according to your system's needs.
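
After installation, it can help to confirm which backend Jax actually picked up. The short check below is a sketch that uses only standard Jax calls (`jax.default_backend` and `jax.devices`); it assumes nothing about Haikumodels itself.

```python
import jax

# Backend Jax selected at start-up: "cpu", "gpu", or "tpu".
backend = jax.default_backend()

# All devices visible to Jax on that backend.
devices = jax.devices()

print("Jax backend:", backend)
print("Devices:", devices)
```

If the backend reported is "cpu" on a machine with a GPU, the GPU build of Jax is most likely not installed.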

Usage examples for image classification models

Classify ImageNet classes with ResNet50

import haiku as hk
import jax
import jax.numpy as jnp
from PIL import Image

import haikumodels as hm

rng = jax.random.PRNGKey(42)


def _model(images, is_training):
  net = hm.ResNet50()
  return net(images, is_training)


model = hk.transform_with_state(_model)

img_path = "elephant.jpg"
img = Image.open(img_path).resize((224, 224))

x = jnp.asarray(img, dtype=jnp.float32)
x = jnp.expand_dims(x, axis=0)
x = hm.resnet.preprocess_input(x)

params, state = model.init(rng, x, is_training=True)

preds, _ = model.apply(params, state, None, x, is_training=False)
# decode the results into a list of tuples (class, description, probability)
# (one such list for each sample in the batch)
print("Predicted:", hm.decode_predictions(preds, top=3)[0])
# Predicted:
# [('n02504013', 'Indian_elephant', 0.8784022),
# ('n01871265', 'tusker', 0.09620289),
# ('n02504458', 'African_elephant', 0.025362419)]
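
The decoding step above turns a batch of class probabilities into labelled top-k lists. The sketch below illustrates that idea only; it is not the implementation of `hm.decode_predictions`. The names `decode_top_k` and `LABELS` are hypothetical, and the four-entry label table stands in for the real 1000-class ImageNet table.

```python
import jax
import jax.numpy as jnp

# Hypothetical label table of (class_id, description) pairs.
LABELS = [
    ("id_0", "cat"),
    ("id_1", "dog"),
    ("id_2", "elephant"),
    ("id_3", "zebra"),
]


def decode_top_k(preds, labels, top=3):
  """Returns, per sample, a list of (class_id, description, probability)."""
  # top_k works along the last axis and returns values sorted descending.
  probs, indices = jax.lax.top_k(preds, top)
  results = []
  for sample_probs, sample_indices in zip(probs, indices):
    results.append([(labels[int(i)][0], labels[int(i)][1], float(p))
                    for p, i in zip(sample_probs, sample_indices)])
  return results


# One sample with four class probabilities.
preds = jnp.array([[0.05, 0.10, 0.80, 0.05]])
decoded = decode_top_k(preds, LABELS, top=2)
print(decoded[0])
```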

Extract features with VGG16

import haiku as hk
import jax
import jax.numpy as jnp
from PIL import Image

import haikumodels as hm

rng = jax.random.PRNGKey(42)

model = hk.without_apply_rng(hk.transform(hm.VGG16(include_top=False)))

img_path = "elephant.jpg"
img = Image.open(img_path).resize((224, 224))

x = jnp.asarray(img, dtype=jnp.float32)
x = jnp.expand_dims(x, axis=0)
x = hm.vgg.preprocess_input(x)

params = model.init(rng, x)

features = model.apply(params, x)
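
With `include_top=False` the model returns a spatial feature map rather than class scores, and a common next step is to pool it into one vector per image. The sketch below uses a random array as a stand-in for `features`; the shape (1, 7, 7, 512) is an assumption based on the standard VGG16 architecture with 224x224 inputs and is not confirmed by this README.

```python
import jax
import jax.numpy as jnp

# Stand-in for `features`; the (batch, height, width, channels) shape is assumed.
features = jax.random.normal(jax.random.PRNGKey(0), (1, 7, 7, 512))

# Global average pooling over the spatial axes gives one vector per image.
pooled = jnp.mean(features, axis=(1, 2))

# L2-normalise so that dot products between vectors are cosine similarities.
embedding = pooled / jnp.linalg.norm(pooled, axis=-1, keepdims=True)

print(pooled.shape, embedding.shape)
```

The same pooling over axes 1 and 2 appears in the fine-tuning example, where it feeds the new dense layers.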

Fine-tune Xception on a new set of classes

from typing import Any, Callable, NamedTuple, Optional, Sequence

import optax
import haiku as hk
import jax
import jax.numpy as jnp

import haikumodels as hm

rng = jax.random.PRNGKey(42)


class Freezable_TrainState(NamedTuple):
  trainable_params: hk.Params
  non_trainable_params: hk.Params
  state: hk.State
  opt_state: optax.OptState


# create your custom top layers and include the desired pretrained model
class ft_xception(hk.Module):

  def __init__(
      self,
      classes: int,
      classifier_activation: Callable[[jnp.ndarray],
                                      jnp.ndarray] = jax.nn.softmax,
      with_bias: bool = True,
      w_init: Callable[[Sequence[int], Any], jnp.ndarray] = None,
      b_init: Callable[[Sequence[int], Any], jnp.ndarray] = None,
      name: Optional[str] = None,
  ):
    super().__init__(name=name)
    self.classifier_activation = classifier_activation

    self.xception_no_top = hm.Xception(include_top=False)
    self.dense_layer = hk.Linear(
        output_size=1024,
        with_bias=with_bias,
        w_init=w_init,
        b_init=b_init,
        name="trainable_dense_layer",
    )
    self.top_layer = hk.Linear(
        output_size=classes,
        with_bias=with_bias,
        w_init=w_init,
        b_init=b_init,
        name="trainable_top_layer",
    )

  def __call__(self, inputs: jnp.ndarray, is_training: bool):
    out = self.xception_no_top(inputs, is_training)
    out = jnp.mean(out, axis=(1, 2))
    out = self.dense_layer(out)
    out = jax.nn.relu(out)
    out = self.top_layer(out)
    out = self.classifier_activation(out)
    return out


# use `transform_with_state` if the model has batchnorm in it
# otherwise use `transform` and then `without_apply_rng`
def _model(images, is_training):
  net = ft_xception(classes=200)
  return net(images, is_training)


model = hk.transform_with_state(_model)

# create your desired optimizer using Optax or alternatives
opt = optax.rmsprop(learning_rate=1e-4, momentum=0.90)


# this function will initialize params and state
# use the desired keyword to divide params into trainable and non_trainable
def initial_state(x_y, nonfreeze_key="trainable"):
  x, _ = x_y
  params, state = model.init(rng, x, is_training=True)

  trainable_params, non_trainable_params = hk.data_structures.partition(
      lambda m, n, p: nonfreeze_key in m, params)

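  # the optimizer only ever sees gradients for the trainable params,
  # so its state must be built from the same subset
  opt_state = opt.init(trainable_params)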

  return Freezable_TrainState(trainable_params, non_trainable_params, state,
                              opt_state)


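# `gen_x_y` is not defined in this example: supply your own iterator
# that yields (images, labels) batches
train_state = initial_state(next(gen_x_y))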


# create your own custom loss function as desired
def loss_function(trainable_params, non_trainable_params, state, x_y):
  x, y = x_y
  params = hk.data_structures.merge(trainable_params, non_trainable_params)
  y_, state = model.apply(params, state, None, x, is_training=True)

  # `categorical_crossentropy` is not defined in this example: supply your own
  cce = categorical_crossentropy(y, y_)

  return cce, state


# to update params and optimizer, a train_step function must be created
@jax.jit
def train_step(train_state: Freezable_TrainState, x_y):
  trainable_params, non_trainable_params, state, opt_state = train_state
  trainable_params_grads, _ = jax.grad(loss_function,
                                       has_aux=True)(trainable_params,
                                                     non_trainable_params,
                                                     state, x_y)

  updates, new_opt_state = opt.update(trainable_params_grads, opt_state)
  new_trainable_params = optax.apply_updates(trainable_params, updates)

  train_state = Freezable_TrainState(new_trainable_params, non_trainable_params,
                                     state, new_opt_state)
  return train_state


# train the model on the new data for a few epochs
train_state = train_step(train_state, next(gen_x_y))

# after training is complete it is possible to merge
# trainable and non_trainable params to use for prediction
trainable_params, non_trainable_params, state, _ = train_state
params = hk.data_structures.merge(trainable_params, non_trainable_params)
preds, _ = model.apply(params, state, None, x, is_training=False)
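
The fine-tuning example uses `gen_x_y` and `categorical_crossentropy` without defining them. The sketch below shows one possible shape for each, purely as an assumption. `make_gen_x_y` is a hypothetical generator that yields random images and one-hot labels in place of a real data pipeline. The loss expects probabilities because `ft_xception` applies softmax by default. The 299x299 input size is also an assumption, borrowed from the usual Xception setup.

```python
import jax
import jax.numpy as jnp


def categorical_crossentropy(y_true, y_pred, eps=1e-7):
  """Mean cross-entropy between one-hot labels and predicted probabilities."""
  y_pred = jnp.clip(y_pred, eps, 1.0 - eps)  # avoid log(0)
  return -jnp.mean(jnp.sum(y_true * jnp.log(y_pred), axis=-1))


def make_gen_x_y(rng, batch_size=8, image_size=299, classes=200):
  """Hypothetical stand-in for a data pipeline: yields (images, one-hot labels)."""
  while True:
    # Split the key on every step so that each batch is different.
    rng, x_key, y_key = jax.random.split(rng, 3)
    x = jax.random.uniform(x_key, (batch_size, image_size, image_size, 3))
    labels = jax.random.randint(y_key, (batch_size,), 0, classes)
    yield x, jax.nn.one_hot(labels, classes)


gen_x_y = make_gen_x_y(jax.random.PRNGKey(0))
x, y = next(gen_x_y)
print(x.shape, y.shape)
```

Real training would replace the random generator with batches of preprocessed images and their labels, keeping the same `(x, y)` tuple layout that `initial_state` and `train_step` unpack.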
Owner
Alper Baris CELIK