Jaxtorch (a jax nn library)

This is my JAX-based neural network library. I created it because I was annoyed by the complexity and 'magic'-ness of the popular JAX frameworks (Flax, Haiku).

The goal is to enable PyTorch-like model definition and training with a minimum of magic. A simple example:

import jax
import jax.numpy as jnp
import jaxlib
import jaxtorch

# Modules are just classes that inherit jaxtorch.Module
class Linear(jaxtorch.Module):
    # They can accept any constructor parameters
    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        # Parameters are represented by a Param type, which identifies
        # them, and specifies how to initialize them.
        self.weight = jaxtorch.init.glorot_normal(out_features, in_features)
        assert type(self.weight) is jaxtorch.Param
        if bias:
            self.bias = jaxtorch.init.zeros(out_features)
        else:
            self.bias = None

    # The forward function accepts cx, a Context object as the first argument
    # always. This provides random number generation as well as the parameters.
    def forward(self, cx: jaxtorch.Context, x):
        # Parameters are looked up in the context using the stored identifier.
        y = x @ jnp.transpose(cx[self.weight])
        # Compare against None explicitly rather than relying on truthiness.
        if self.bias is not None:
            y = y + cx[self.bias]
        return y

model = Linear(3, 3)

# You initialize the weights by passing a RNG key.
# Calling init_weights also names all the parameters in the Module tree.
params = model.init_weights(jax.random.PRNGKey(0))

# Parameters are stored in a dictionary by name.
assert type(params) is dict
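# Each value is a JAX array (jaxlib.xla_extension.DeviceArray on older JAX
# releases); jnp.ndarray works as an isinstance target across versions.
assert isinstance(params[model.weight.name], jnp.ndarray)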
assert model.weight.name == 'weight'

def loss(params, key):
    cx = jaxtorch.Context(params, key)
    x = jnp.array([1.0,2.0,3.0])
    y = jnp.array([4.0,5.0,6.0])
    return jnp.mean((model(cx, x) - y)**2)
f_grad = jax.value_and_grad(loss)

for _ in range(100):
    # Use a distinct name so the loss function defined above is not shadowed.
    (loss_value, grad) = f_grad(params, jax.random.PRNGKey(0))
    params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grad)
print(loss_value)
# 4.7440533e-08
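
The lookup-by-identifier pattern used in the example can be sketched in plain JAX, without jaxtorch installed. This is not jaxtorch's implementation: `TinyParam`, `TinyContext`, `init_weights`, `forward`, `loss_fn` and `sgd_step` are hypothetical names made up for this illustration, and the initializer is a plain scaled normal rather than Glorot. The point it tries to show is that a dict of named arrays is an ordinary pytree, so `jax.value_and_grad`, `jax.jit` and `tree_map` can operate on it directly.

```python
import jax
import jax.numpy as jnp


class TinyParam:
    """Hypothetical parameter handle: just a name and a shape."""

    def __init__(self, name, shape):
        self.name = name
        self.shape = shape


class TinyContext:
    """Hypothetical context: carries the parameter dict and an RNG key."""

    def __init__(self, params, key):
        self.params = params
        self.key = key  # carried along, unused in this deterministic sketch

    def __getitem__(self, param):
        # Resolve a handle to its array via the stored name.
        return self.params[param.name]


weight = TinyParam("weight", (3, 3))
bias = TinyParam("bias", (3,))


def init_weights(key):
    # Illustrative init only: scaled normal for the weight, zeros for the bias.
    return {
        weight.name: 0.1 * jax.random.normal(key, weight.shape),
        bias.name: jnp.zeros(bias.shape),
    }


def forward(cx, x):
    return x @ jnp.transpose(cx[weight]) + cx[bias]


def loss_fn(params, key):
    cx = TinyContext(params, key)
    x = jnp.array([1.0, 2.0, 3.0])
    y = jnp.array([4.0, 5.0, 6.0])
    return jnp.mean((forward(cx, x) - y) ** 2)


@jax.jit
def sgd_step(params, key):
    # Gradients are taken with respect to the first argument, the params dict.
    value, grads = jax.value_and_grad(loss_fn)(params, key)
    params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)
    return params, value


key = jax.random.PRNGKey(0)
params = init_weights(key)
initial_loss = float(loss_fn(params, key))
for _ in range(100):
    params, value = sgd_step(params, key)
final_loss = float(loss_fn(params, key))
print(initial_loss, final_loss)
```

Because the context is rebuilt inside `loss_fn` on every call, the function stays pure from JAX's point of view, which is what allows the whole update step to be wrapped in `jax.jit`.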

Owner: nshepperd