PyTorch implementation of GLOM

GLOM

PyTorch implementation of GLOM, Geoffrey Hinton's new idea that integrates concepts from neural fields, top-down-bottom-up processing, and attention (consensus between columns).

1. Overview

An implementation of Geoffrey Hinton's paper "How to represent part-whole hierarchies in a neural network", applied to the MNIST dataset.

2. Usage

2 - 1. PyTorch version

import torch
from pyglom import GLOM

model = GLOM(
    dim = 512,         # dimension
    levels = 6,        # number of levels
    image_size = 224,  # image size
    patch_size = 14    # patch size
)

img = torch.randn(1, 3, 224, 224)
levels = model(img, iters = 12) # (1, 256, 6, 512) - (batch - patches - levels - dimension)
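
The shape in the comment above follows from the patch grid: a 224 x 224 image cut into 14 x 14 patches gives 16 x 16 = 256 columns, each holding 6 levels of 512-dimensional embeddings. A minimal sketch of that arithmetic in plain Python follows. The helper name `glom_output_shape` is made up for illustration and is not part of pyglom, and it assumes the image size divides evenly by the patch size.

```python
def glom_output_shape(batch, image_size, patch_size, levels, dim):
    """Return the (batch, patches, levels, dim) shape described above.

    Hypothetical helper for illustration only; it is not part of pyglom.
    """
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    num_patches = patches_per_side ** 2  # one column per patch
    return (batch, num_patches, levels, dim)


shape = glom_output_shape(batch=1, image_size=224, patch_size=14, levels=6, dim=512)
print(shape)  # (1, 256, 6, 512)
```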

Pass return_all = True to the forward call to get back the column and level states at every iteration. The result includes the initial state, so it has iters + 1 time steps. You can then attach a loss to any level's output at any time step.

It also exposes the level states across iterations, which can be clustered to look for the islands of agreement theorized in the paper.
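
One simple way to look for islands, short of full clustering, is to check how closely neighboring patches agree at a given level. The sketch below is an assumption about how such an inspection might be done, not something pyglom provides. `neighbor_agreement` is a hypothetical helper, the random tensor stands in for the output of return_all = True, and patches are assumed to be laid out row by row on a 16 x 16 grid.

```python
import torch
import torch.nn.functional as F


def neighbor_agreement(level_state, grid_size):
    """Cosine similarity between each patch and its right-hand neighbor.

    level_state: (patches, dim) embeddings for one level at one time step.
    Returns a (grid_size, grid_size - 1) tensor; values near 1 suggest the
    two patches belong to the same island.
    Hypothetical helper, not part of pyglom.
    """
    grid = level_state.reshape(grid_size, grid_size, -1)
    left, right = grid[:, :-1], grid[:, 1:]
    return F.cosine_similarity(left, right, dim=-1)


# Stand-in for `all_levels`: (time, batch, patches, levels, dimension)
all_levels = torch.randn(13, 1, 256, 6, 512)

top_final = all_levels[-1, 0, :, -1]            # last time step, first image, top level
agreement = neighbor_agreement(top_final, 16)   # 16 = 224 // 14
print(agreement.shape)  # torch.Size([16, 15])
```

With a trained model, one would expect agreement to rise over time steps and at higher levels if islands form. Random inputs, as used here, show no such structure.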

import torch
from pyglom import GLOM

model = GLOM(
    dim = 512,         # dimension
    levels = 6,        # number of levels
    image_size = 224,  # image size
    patch_size = 14    # patch size
)

img = torch.randn(1, 3, 224, 224)
all_levels = model(img, iters = 12, return_all = True) # (13, 1, 256, 6, 512) - (time, batch, patches, levels, dimension)

# get the top level outputs after iteration 6 (counting from 0; index 0 is the initial state)
top_level_output = all_levels[7, :, :, -1] # (1, 256, 512) - (batch, patches, dimension)

Denoising self-supervised learning to encourage emergence, as described by Hinton:

import torch
import torch.nn.functional as F
from torch import nn
from einops.layers.torch import Rearrange

from pyglom import GLOM

model = GLOM(
    dim = 512,         # dimension
    levels = 6,        # number of levels
    image_size = 224,  # image size
    patch_size = 14    # patch size
)

img = torch.randn(1, 3, 224, 224)
noised_img = img + torch.randn_like(img)

all_levels = model(noised_img, return_all = True)

patches_to_images = nn.Sequential(
    nn.Linear(512, 14 * 14 * 3),
    Rearrange('b (h w) (p1 p2 c) -> b c (h p1) (w p2)', p1 = 14, p2 = 14, h = (224 // 14))
)

top_level = all_levels[7, :, :, -1]  # get the top level embeddings after iteration 6
recon_img = patches_to_images(top_level)

# do self-supervised learning by denoising

loss = F.mse_loss(recon_img, img)  # (input, target): reconstruction vs. clean image
loss.backward()
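
For readers unfamiliar with einops notation, the Rearrange pattern above folds each patch's flat pixel vector back into its place in the image. The same reshaping can be sketched with plain tensor operations. `unpatchify` is a hypothetical helper written for this illustration, and it assumes a square patch grid.

```python
import math

import torch


def unpatchify(patches, patch_size, channels):
    """Plain-PyTorch version of 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)'.

    patches: (batch, num_patches, patch_size * patch_size * channels)
    Hypothetical helper for illustration; assumes a square grid of patches.
    """
    b, n, d = patches.shape
    side = math.isqrt(n)
    if side * side != n or d != patch_size * patch_size * channels:
        raise ValueError("unexpected patch tensor shape")
    # split the flat axes: b, h, w, p1, p2, c
    x = patches.reshape(b, side, side, patch_size, patch_size, channels)
    # reorder to b, c, h, p1, w, p2 so rows and columns can be merged
    x = x.permute(0, 5, 1, 3, 2, 4)
    return x.reshape(b, channels, side * patch_size, side * patch_size)


pixels = torch.randn(1, 256, 14 * 14 * 3)   # e.g. the output of nn.Linear(512, 14 * 14 * 3)
image = unpatchify(pixels, patch_size=14, channels=3)
print(image.shape)  # torch.Size([1, 3, 224, 224])
```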

You can pass the column and level states back into the model to continue where you left off, for example when processing consecutive frames of a slow video, as mentioned in the paper.

import torch
from pyglom import GLOM

model = GLOM(
    dim = 512,
    levels = 6,
    image_size = 224,
    patch_size = 14
)

img1 = torch.randn(1, 3, 224, 224)
img2 = torch.randn(1, 3, 224, 224)
img3 = torch.randn(1, 3, 224, 224)

levels1 = model(img1, iters = 12)                   # image 1 for 12 iterations
levels2 = model(img2, levels = levels1, iters = 10) # image 2 for 10 iterations
levels3 = model(img3, levels = levels2, iters = 6)  # image 3 for 6 iterations
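
For a longer sequence of frames, the same pattern can be written as a loop. The sketch below assumes only the call signature shown above (an image plus the levels and iters keywords). `run_on_frames` and `toy_model` are hypothetical names, the iteration counts are illustrative, and the toy model stands in for a GLOM instance so the snippet runs without pyglom.

```python
def run_on_frames(model, frames, first_iters=12, later_iters=6):
    """Feed frames in order, passing each returned state back in as `levels`.

    Hypothetical helper; `model` is any callable with the signature
    model(img, levels=..., iters=...) used in the example above.
    """
    levels = None
    states = []
    for index, frame in enumerate(frames):
        iters = first_iters if index == 0 else later_iters
        if levels is None:
            levels = model(frame, iters=iters)                  # start from scratch
        else:
            levels = model(frame, levels=levels, iters=iters)   # continue from the previous state
        states.append(levels)
    return states


def toy_model(frame, levels=None, iters=1):
    # Toy stand-in for a GLOM instance: tracks the total iterations run so far.
    return (levels or 0) + iters


states = run_on_frames(toy_model, ["frame1", "frame2", "frame3"])
print(states)  # [12, 18, 24]
```

Later frames get fewer iterations here on the assumption that consecutive frames change little, so the carried-over state is already close to a settled one.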

2 - 2. PyTorch-Lightning version

pyglom also provides a GLOM model implemented with PyTorch-Lightning.

import os

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST

from pyglom.glom import LightningGLOM

dataset = MNIST(os.getcwd(), download=True, transform=transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
]))
train, val = random_split(dataset, [55000, 5000])

glom = LightningGLOM(
    dim=256,         # dimension
    levels=6,        # number of levels
    image_size=256,  # image size
    patch_size=16,   # patch size
    img_channels=1
)

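gpus = torch.cuda.device_count()  # number of visible GPUs; 0 means train on CPU
# note: newer PyTorch-Lightning releases replace `gpus` with `accelerator` / `devices`
trainer = pl.Trainer(gpus=gpus, max_epochs=5)
trainer.fit(glom, DataLoader(train, batch_size=8, num_workers=2), DataLoader(val, batch_size=8, num_workers=2))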

3. ToDo

  • contrastive / consistency regularization of top-ish levels

4. Citations

@misc{hinton2021represent,
    title   = {How to represent part-whole hierarchies in a neural network}, 
    author  = {Geoffrey Hinton},
    year    = {2021},
    eprint  = {2102.12627},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}

Releases: 0.0.3
Owner: Yeonwoo Sung