Tutorial for surrogate gradient learning in spiking neural networks

SpyTorch

A tutorial on surrogate gradient learning in spiking neural networks

Version: 0.4

This repository contains tutorial files to get you started with the basic ideas of surrogate gradient learning in spiking neural networks using PyTorch.

A brief introductory video accompanying these notebooks is available at https://youtu.be/xPYiAjceAqU

Feedback and contributions are welcome.

For more information on surrogate gradient learning please refer to:

Neftci, E.O., Mostafa, H., and Zenke, F. (2019). Surrogate Gradient Learning in Spiking Neural Networks: Bringing the Power of Gradient-Based Optimization to Spiking Neural Networks. IEEE Signal Processing Magazine 36, 51–63. https://ieeexplore.ieee.org/document/8891809 (preprint: https://arxiv.org/abs/1901.09948)

Also see https://github.com/surrogate-gradient-learning

Copyright and license

Copyright 2019-2020 Friedemann Zenke, https://fzenke.net

This work is licensed under a Creative Commons Attribution 4.0 International License. http://creativecommons.org/licenses/by/4.0/

Comments
  • resetting with "out" instead of "rst"?

    This is a comment, not an issue.

    Hi Friedemann,

    First off, thanks a lot for these great tutorials. I've really enjoyed playing with them, and I've learned a lot :-)

    One question: in the run_snn function, why do you bother constructing the "rst" tensor? Why don't you subtract the "out" tensor, which also contains the output spikes? I've tried, and it seems to work. Just curious.

    Best,

    Tim

    question 
    opened by tmasquelier 8
  • Problem in SpyTorchTutorial2

    Hello,

    It was a very nice and interesting tutorial, thank you for preparing it...

    Tutorial 1 ran without problems, but in tutorial 2 some dtype problems occurred. After fixing them, training was very slow on a GTX 980 (a configuration on which I have run some very deep models). Could you please describe your configuration, along with the training time and response time?

    opened by ghost 6
  • Spike times shifted

    I have the impression that the spike recordings are shifted one time step in all tutorials. Could you maybe check if this is indeed the case?

    From my understanding, time step 0 is recorded twice for the spikes, once during initialisation

      mem = torch.zeros((batch_size, nb_hidden), device=device, dtype=dtype)
      spk_rec = [mem]
    

    and once within the simulation of time step 0:

      for t in range(nb_steps):
          mthr = mem-1.0
          out = spike_fn(mthr)
          ...
          spk_rec.append(out)
    

    As a result, the indices appear shifted when comparing:

      print(torch.nonzero((mem_rec-1.0) > 0.0))
      print(torch.nonzero(spk_rec))
    

    Thanks, Simon

    opened by smonsays 4
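
The offset described in this issue can be reproduced without PyTorch. The following is a minimal plain-Python sketch, not the tutorial code. It assumes a single leaky neuron, an illustrative decay constant beta, reset by subtraction, and a membrane recording mem_rec that is appended after each update. Only the initialisation spk_rec = [mem] and the spk_rec.append(out) pattern are taken from the snippets quoted in the issue.

```python
def simulate(inputs, beta=0.9, threshold=1.0):
    """Simulate one leaky neuron and record membrane potential and spikes.

    beta and the update rule are illustrative assumptions, not the
    tutorial's actual values.
    """
    mem = 0.0
    mem_rec = [mem]  # index 0: initial membrane potential
    spk_rec = [mem]  # index 0: placeholder entry, as in the quoted initialisation
    for x in inputs:
        # The spike is computed from the membrane potential *before* the update.
        out = 1.0 if mem - threshold > 0.0 else 0.0
        mem = beta * mem + x - out  # leak, input, reset by subtraction
        mem_rec.append(mem)
        spk_rec.append(out)
    return mem_rec, spk_rec


mem_rec, spk_rec = simulate([0.6, 0.6, 0.6, 0.0, 0.0])

# Indices where the recorded membrane potential exceeds the threshold.
crossings = [i for i, m in enumerate(mem_rec) if m - 1.0 > 0.0]
# Indices where a spike was recorded.
spikes = [i for i, s in enumerate(spk_rec) if s > 0.0]

print(crossings, spikes)
```

In this sketch each spike index is one larger than the index of the threshold crossing that caused it, which matches the shift reported in the issue. Dropping the placeholder entry from spk_rec would align the two recordings here; whether that is the right fix for the tutorials is a separate question.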
  • Software/Machine description available?

    Hey Friedemann,

    thanks for making the examples available, they look very helpful. However, to make them fully reproducible I think that some additional information regarding the "technical dependencies" is needed.

    In particular, a list of the software packages used (including version and build variant information) and a specification of the machine hardware (CPU architecture, GPUs) would help.

    Preferably, the former could be expressed as a recipe for constructing a container (Dockerfile, or for better HPC-compatibility, a Singularity recipe), maybe even using an explicitly versioning package manager like spack.

    Cheers, Eric

    opened by muffgaga 3
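
As a starting point short of a full container recipe, the interpreter, platform, and package versions can be collected with the Python standard library. This is a minimal sketch: the helper name describe_environment and the default package list are assumptions made for illustration, and GPU details are not covered because they require vendor tools or the framework itself.

```python
import platform
from importlib import metadata


def describe_environment(packages=("torch", "numpy")):
    """Collect interpreter, platform, and package versions in a dict.

    The default package list is a hypothetical example; adjust it to the
    packages the notebooks actually import.
    """
    info = {
        "python": platform.python_version(),
        "system": platform.system(),
        "machine": platform.machine(),  # CPU architecture as reported by the OS
    }
    for name in packages:
        try:
            info[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            info[name] = "not installed"
    return info


env = describe_environment()
for key, value in env.items():
    print(f"{key}: {value}")
```

Saving this output next to the notebooks records which versions produced a given result, although it does not replace a pinned container or package-manager recipe.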
  • Dataset never decompressed

    Hello,

    I believe I ran into a possible issue here. Because of line 37, the condition on line 38 always evaluates to false unless the uncompressed dataset is already present.

    https://github.com/fzenke/spytorch/blob/9e91eceaf53f17be9e95a3743164224bdbb086bb/notebooks/utils.py#L35-L42

    If I change line 37 to

      hdf5_file_path = gz_file_path[:-3]

    it works for me.

    Best, Aaron

    opened by AaronSpieler 1
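
The proposed slice gz_file_path[:-3] strips the trailing ".gz" from the archive path, so the existence check looks at the decompressed file and not at the archive. The following is a minimal sketch of that pattern using only the standard library. The helper gunzip_if_needed and the file name dataset.h5.gz are hypothetical, and the function is not the one in utils.py.

```python
import gzip
import os
import shutil
import tempfile


def gunzip_if_needed(gz_file_path):
    """Decompress gz_file_path next to itself unless the target exists.

    Hypothetical helper for illustration only.
    """
    if not gz_file_path.endswith(".gz"):
        raise ValueError("expected a path ending in .gz")
    # Strip the trailing ".gz" (3 characters) to get the target path.
    out_path = gz_file_path[:-3]
    if not os.path.isfile(out_path):
        with gzip.open(gz_file_path, "rb") as f_in, open(out_path, "wb") as f_out:
            shutil.copyfileobj(f_in, f_out)
    return out_path


# Demo with a throwaway file standing in for the dataset archive.
with tempfile.TemporaryDirectory() as tmp_dir:
    archive = os.path.join(tmp_dir, "dataset.h5.gz")
    with gzip.open(archive, "wb") as f:
        f.write(b"spike data")
    restored_path = gunzip_if_needed(archive)
    with open(restored_path, "rb") as f:
        restored = f.read()
    restored_name = os.path.basename(restored_path)

print(restored_name, restored)
```

If the target path were identical to the archive path, the existence check would succeed as soon as the download finished and decompression would be skipped, which is consistent with the behaviour described in the issue.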
  • propagation delay

    Hi Zenke, I have a question about the SNN model. If I feed a spike image to an SNN with L layers at time step n, the input does not affect the output of the last layer until time step n + L - 1. In deep networks this delay should be taken into account, because it increases the total number of time steps required.

    opened by yizx6 1
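
The arithmetic in this question can be checked with a small plain-Python sketch. It assumes that the first layer responds in the same step as the input and that every further layer reads the previous layer's output from the previous time step. Both are assumptions made for illustration, and the function name first_response_step is hypothetical.

```python
def first_response_step(nb_layers, input_step, nb_steps):
    """Return the first time step at which the last layer is active.

    Returns None if the signal does not reach the last layer within
    nb_steps simulated steps.
    """
    # active[l][t] is True if layer l + 1 emits a spike at step t.
    active = [[False] * nb_steps for _ in range(nb_layers)]
    for t in range(nb_steps):
        active[0][t] = (t == input_step)
        for layer in range(1, nb_layers):
            # Each layer sees the previous layer's output one step late.
            active[layer][t] = t > 0 and active[layer - 1][t - 1]
    last = active[-1]
    return last.index(True) if True in last else None


n, L = 3, 5
print(first_response_step(nb_layers=L, input_step=n, nb_steps=20))  # n + L - 1 = 7
```

Under these assumptions the simulation must run for at least n + L steps for the last layer to respond to an input presented at step n.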
  • Compute recurrent contribution from spikes

    Hey Friedemann,

    thank you for the very comprehensive tutorial! I have a question on the way the recurrence is computed in tutorial 4. If I understand the equation for the dynamics of the current correctly, the recurrence should be computed with the spiking neuron state:

    mthr = mem-1.0
    out = spike_fn(mthr)
    h1 = h1_from_input[:,t] + torch.einsum("ab,bc->ac", (out, v1))
    

    Instead, tutorial 4 keeps a separate hidden state that ignores the spike function:

    h1 = h1_from_input[:,t] + torch.einsum("ab,bc->ac", (h1, v1))
    

    Is this done deliberately? Judging from simulating a few epochs, the two versions seem to perform similarly.

    Thank you,

    Simon

    opened by smonsays 1
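
To make the difference between the two variants concrete, here is a plain-Python sketch of a single time step. The helper matmul stands in for torch.einsum("ab,bc->ac", ...). All numeric values, the batch size of one, and the two hidden units are assumptions chosen for illustration.

```python
def matmul(a, b):
    """Plain-Python stand-in for einsum("ab,bc->ac")."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def add(a, b):
    """Element-wise sum of two equally shaped nested lists."""
    return [[x + y for x, y in zip(row_a, row_b)] for row_a, row_b in zip(a, b)]


def spike_fn(x):
    """Heaviside step: 1.0 where the input is positive, else 0.0."""
    return [[1.0 if v > 0 else 0.0 for v in row] for row in x]


# Illustrative values for a batch of one and two hidden units.
mem = [[1.2, 0.3]]      # membrane potentials
h1_prev = [[0.4, 0.2]]  # hidden state from the previous step
input_t = [[0.1, 0.1]]  # stands in for h1_from_input[:, t]
v1 = [[0.0, 0.5],
      [0.5, 0.0]]       # recurrent weights

out = spike_fn([[m - 1.0 for m in row] for row in mem])

# Variant from the question: recurrence driven by the emitted spikes.
h1_from_spikes = add(input_t, matmul(out, v1))
# Variant quoted from tutorial 4: recurrence driven by the previous hidden state.
h1_from_state = add(input_t, matmul(h1_prev, v1))

print(out, h1_from_spikes, h1_from_state)
```

In the first variant only neurons that crossed the threshold contribute to the recurrent input, and each contributes exactly one weight row. In the second variant every unit contributes in proportion to its previous hidden state, whether or not it spiked.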
  • maybe simplification

    I don't understand why the 'rst' variable exists. It always seems to be equal to 'out'. Changing it to rst = out yields the same results...

    def spike_fn(x):
        out = torch.zeros_like(x)
        out[x > 0] = 1.0
        return out
    ...
    # Here we loop over time
    for t in range(nb_steps):
        mthr = mem-1.0
        out = spike_fn(mthr) 
        rst = torch.zeros_like(mem)
        c = (mthr > 0)
        rst[c] = torch.ones_like(mem)[c] 
    
    opened by colinator 1
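
The observation that "rst" and "out" hold the same values can be checked on a small example. The sketch below re-expresses the quoted tensor operations with plain Python lists, so it is a stand-in and not the tutorial code. The helper name reset_mask and the membrane potentials are illustrative assumptions.

```python
def spike_fn(x):
    """List version of the quoted spike_fn: 1.0 where x > 0, else 0.0."""
    out = [0.0] * len(x)
    for i, v in enumerate(x):
        if v > 0:
            out[i] = 1.0
    return out


def reset_mask(mthr):
    """List version of the quoted "rst" construction."""
    rst = [0.0] * len(mthr)
    ones = [1.0] * len(mthr)
    for i, v in enumerate(mthr):
        if v > 0:  # corresponds to c = (mthr > 0)
            rst[i] = ones[i]
    return rst


mem = [0.2, 1.0, 1.5, -0.3, 2.0]  # illustrative membrane potentials
mthr = [m - 1.0 for m in mem]
out = spike_fn(mthr)
rst = reset_mask(mthr)

print(out == rst)
```

The comparison covers forward values only. It says nothing about how gradients flow through either tensor during training, so equal forward values alone do not establish that the two are interchangeable in every respect.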
  • Issue in running Tutorial-4

    When I run the following code in Tutorial-4:

    loss_hist = train(x_train, y_train, lr=2e-4, nb_epochs=nb_epochs)

    I get the following error: pic3

    Could you please suggest how to resolve this issue?

    opened by paglabhola 0
Releases: v0.3
Owner: Friedemann Zenke