A PyTorch implementation of "Graph Wavelet Neural Network" (ICLR 2019)

Overview

Graph Wavelet Neural Network

Arxiv codebeat badge repo sizebenedekrozemberczki

A PyTorch implementation of Graph Wavelet Neural Network (ICLR 2019).

Abstract

We present graph wavelet neural network (GWNN), a novel graph convolutional neural network (CNN), leveraging graph wavelet transform to address the shortcomings of previous spectral graph CNN methods that depend on graph Fourier transform. Different from graph Fourier transform, graph wavelet transform can be obtained via a fast algorithm without requiring matrix eigendecomposition with high computational cost. Moreover, graph wavelets are sparse and localized in vertex domain, offering high efficiency and good interpretability for graph convolution. The proposed GWNN significantly outperforms previous spectral graph CNNs in the task of graph-based semi-supervised classification on three benchmark datasets: Cora, Citeseer and Pubmed.

A reference Tensorflow implementation is accessible [here].

This repository provides an implementation of Graph Wavelet Neural Network as described in the paper:

Graph Wavelet Neural Network. Bingbing Xu, Huawei Shen, Qi Cao, Yunqi Qiu, Xueqi Cheng. ICLR, 2019. [Paper]


Requirements

The codebase is implemented in Python 3.5.2. package versions used for development are just below.

networkx          2.4
tqdm              4.28.1
numpy             1.15.4
pandas            0.23.4
texttable         1.5.0
scipy             1.1.0
argparse          1.1.0
torch             1.1.0
torch-scatter     1.4.0
torch-sparse      0.4.3
torch-cluster     1.4.5
torch-geometric   1.3.2
torchvision       0.3.0
scikit-learn      0.20.0
PyGSP             0.5.1

Datasets

The code takes the **edge list** of the graph in a csv file. Every row indicates an edge between two nodes separated by a comma. The first row is a header. Nodes should be indexed starting with 0. A sample graph for `Cora` is included in the `input/` directory. In addition to the edgelist there is a JSON file with the sparse features and a csv with the target variable.

The **feature matrix** is a sparse binary one it is stored as a json. Nodes are keys of the json and feature indices are the values. For each node feature column ids are stored as elements of a list. The feature matrix is structured as:

{ 0: [0, 1, 38, 1968, 2000, 52727],
  1: [10000, 20, 3],
  2: [],
  ...
  n: [2018, 10000]}

The **target vector** is a csv with two columns and headers, the first contains the node identifiers the second the targets. This csv is sorted by node identifiers and the target column contains the class meberships indexed from zero.

NODE ID Target
0 3
1 1
2 0
3 1
... ...
n 3

Options

Training the model is handled by the src/main.py script which provides the following command line arguments.

Input and output options

  --edge-path        STR   Input graph path.   Default is `input/cora_edges.csv`.
  --features-path    STR   Features path.      Default is `input/cora_features.json`.
  --target-path      STR   Target path.        Default is `input/cora_target.csv`.
  --log-path         STR   Log path.           Default is `logs/cora_logs.json`.

Model options

  --epochs                INT       Number of Adam epochs.         Default is 200.
  --learning-rate         FLOAT     Number of training epochs.     Default is 0.01.
  --weight-decay          FLOAT     Weight decay.                  Default is 5*10**-4.
  --filters               INT       Number of filters.             Default is 16.
  --dropout               FLOAT     Dropout probability.           Default is 0.5.
  --test-size             FLOAT     Test set ratio.                Default is 0.2.
  --seed                  INT       Random seeds.                  Default is 42.
  --approximation-order   INT       Chebyshev polynomial order.    Default is 3.
  --tolerance             FLOAT     Wavelet coefficient limit.     Default is 10**-4.
  --scale                 FLOAT     Heat kernel scale.             Default is 1.0.

Examples

The following commands learn the weights of a graph wavelet neural network and saves the logs. The first example trains a graph wavelet neural network on the default dataset with standard hyperparameter settings. Saving the logs at the default path.

python src/main.py

Training a model with more filters in the first layer.

python src/main.py --filters 32

Approximationg the wavelets with polynomials that have an order of 5.

python src/main.py --approximation-order 5

License


Comments
  • what's the meanning of the

    what's the meanning of the "feature matrix"?

    Hello author, sorry about a stupid question. But the Cora dataset has Cora.cites corresponding your cora_edges.csv, and Cora.content's paper index and paper category for your cora_target.csv, so I don't understand the meanning of your cora_features.json . In the beginning, I just think it's an adjacency matrix of all nodes(paper index), however, the content are inconsistent. Such as ,in cora_edges.csv it's as the picture as follw: image and in cora_features.json it's : image So I am confused , and hope for your answer. Thank you very much.

    opened by CindyTing 7
  • How can l use this code for graph classification ?

    How can l use this code for graph classification ?

    Hi @benedekrozemberczki ,

    Let me first thank you for this promising work.

    I would like to apply your GWNN to graph classification problems rather than nodes classification.

    Do you have any extension for that ?

    Thank you

    opened by Benjiou 4
  • the kernel

    the kernel

    Hi, author, There was a variable in the code called diagnoal_weight_filter 屏幕截图 2021-01-16 204442 I think the variable should change in the trainning time,but it never changed when I debugging. It's so confusing. And I wonder if the variable conduct the same role as the diagnoal_weight_filer in the tensorflow implementation will change.

    opened by maxmit233 3
  • Fatal Python error: Segmentation fault

    Fatal Python error: Segmentation fault

    hi, author. These days i've been watching the program. But when I run on this code, I find an error happened during the time. Can you give me some suggestions?

    image

    image

    opened by Evelyn-coder 2
  • something about wavelet basis

    something about wavelet basis

    Hello~, Thank you for your paper. when I read the paper, I think about what is the connection between wavelet basis and Fourier basis, can you give me some tips?

    opened by ICDI0906 1
  • RuntimeError: the derivative for 'index' is not implemented

    RuntimeError: the derivative for 'index' is not implemented

    Hello, I was running the example and got this error.

    python src/main.py
    +---------------------+----------------------------+
    |      Parameter      |           Value            |
    +=====================+============================+
    | Approximation order | 20                         |
    +---------------------+----------------------------+
    | Dropout             | 0.500                      |
    +---------------------+----------------------------+
    | Edge path           | ./input/cora_edges.csv     |
    +---------------------+----------------------------+
    | Epochs              | 300                        |
    +---------------------+----------------------------+
    | Features path       | ./input/cora_features.json |
    +---------------------+----------------------------+
    | Filters             | 16                         |
    +---------------------+----------------------------+
    | Learning rate       | 0.001                      |
    +---------------------+----------------------------+
    | Log path            | ./logs/cora_logs.json      |
    +---------------------+----------------------------+
    | Scale               | 1                          |
    +---------------------+----------------------------+
    | Seed                | 42                         |
    +---------------------+----------------------------+
    | Target path         | ./input/cora_target.csv    |
    +---------------------+----------------------------+
    | Test size           | 0.200                      |
    +---------------------+----------------------------+
    | Tolerance           | 0.000                      |
    +---------------------+----------------------------+
    | Weight decay        | 0.001                      |
    +---------------------+----------------------------+
    
    Wavelet calculation and sparsification started.
    
    100%|███████████████████████████████████████████████████████████████████████████████████| 2708/2708 [00:11<00:00, 237.23it/s]
    100%|███████████████████████████████████████████████████████████████████████████████████| 2708/2708 [00:11<00:00, 228.91it/s]
    
    Normalizing the sparsified wavelets.
    
    Density of wavelets: 0.2%.
    Density of inverse wavelets: 0.04%.
    
    Training.
    
    Loss:   0%|                                                                                          | 0/300 [00:00<?, ?it/s]Traceback (most recent call last):
      File "src/main.py", line 24, in <module>
        main()
      File "src/main.py", line 18, in main
        trainer.fit()
      File "/home/paperspace/Thesis/GraphWaveletNeuralNetwork/src/gwnn.py", line 131, in fit
        prediction = self.model(self.phi_indices, self.phi_values , self.phi_inverse_indices, self.phi_inverse_values, self.feature_indices, self.feature_values)
      File "/home/paperspace/miniconda2/envs/thesis/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/paperspace/Thesis/GraphWaveletNeuralNetwork/src/gwnn.py", line 44, in forward
        deep_features_1 = self.convolution_1(phi_indices, phi_values, phi_inverse_indices, phi_inverse_values, feature_indices, feature_values, self.args.dropout)
      File "/home/paperspace/miniconda2/envs/thesis/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/paperspace/Thesis/GraphWaveletNeuralNetwork/src/gwnn_layer.py", line 55, in forward
        localized_features = spmm(phi_product_indices, phi_product_values, self.ncount, filtered_features)
      File "/home/paperspace/miniconda2/envs/thesis/lib/python3.6/site-packages/torch_sparse/spmm.py", line 21, in spmm
        out = scatter_add(out, row, dim=0, dim_size=m)
      File "/home/paperspace/miniconda2/envs/thesis/lib/python3.6/site-packages/torch_scatter/add.py", line 73, in scatter_add
        return out.scatter_add_(dim, index, src)
    RuntimeError: the derivative for 'index' is not implemented
    
    opened by youjinChung 1
Releases(v_00001)
Owner
Benedek Rozemberczki
Machine Learning Engineer at AstraZeneca | PhD from The University of Edinburgh.
Benedek Rozemberczki
Pocsploit is a lightweight, flexible and novel open source poc verification framework

Pocsploit is a lightweight, flexible and novel open source poc verification framework

cckuailong 208 Dec 24, 2022
Python package provinding tools for artistic interactive applications using AI

Documentation redrawing Python package provinding tools for artistic interactive applications using AI Created by ReDrawing Campinas team for the Open

ReDrawing Campinas 1 Sep 30, 2021
Unofficial PyTorch implementation of "RTM3D: Real-time Monocular 3D Detection from Object Keypoints for Autonomous Driving" (ECCV 2020)

RTM3D-PyTorch The PyTorch Implementation of the paper: RTM3D: Real-time Monocular 3D Detection from Object Keypoints for Autonomous Driving (ECCV 2020

Nguyen Mau Dzung 271 Nov 29, 2022
Utilities to bridge Canvas-generated course rosters with GitLab's API.

gitlab-canvas-utils A collection of scripts originally written for CSE 13S. Oversees everything from GitLab course group creation, student repository

Eugene Chou 5 Jun 08, 2022
Official Implementation of "Learning Disentangled Behavior Embeddings"

DBE: Disentangled-Behavior-Embedding Official implementation of Learning Disentangled Behavior Embeddings (NeurIPS 2021). Environment requirement The

Mishne Lab 12 Sep 28, 2022
Deep Learning (with PyTorch)

Deep Learning (with PyTorch) This notebook repository now has a companion website, where all the course material can be found in video and textual for

Alfredo Canziani 6.2k Jan 07, 2023
Code for Dual Contrastive Learning for Unsupervised Image-to-Image Translation, NTIRE, CVPRW 2021.

arXiv Dual Contrastive Learning Adversarial Generative Networks (DCLGAN) We provide our PyTorch implementation of DCLGAN, which is a simple yet powerf

119 Dec 04, 2022
MobileNetV1-V2,MobileNeXt,GhostNet,AdderNet,ShuffleNetV1-V2,Mobile+ViT etc.

MobileNetV1-V2,MobileNeXt,GhostNet,AdderNet,ShuffleNetV1-V2,Mobile+ViT etc. ⭐⭐⭐⭐⭐

568 Jan 04, 2023
Mini Software that give reminder to drink water as per your weight.

Water Notification Desktop Python The Mini Software built in Python (tkinter) that will remind you to drink water on specific time span based on your

Om Jogani 5 Dec 16, 2022
[ICML 2021] "Graph Contrastive Learning Automated" by Yuning You, Tianlong Chen, Yang Shen, Zhangyang Wang

Graph Contrastive Learning Automated PyTorch implementation for Graph Contrastive Learning Automated [talk] [poster] [appendix] Yuning You, Tianlong C

Shen Lab at Texas A&M University 80 Nov 23, 2022
i-RevNet Pytorch Code

i-RevNet: Deep Invertible Networks Pytorch implementation of i-RevNets. i-RevNets define a family of fully invertible deep networks, built from a succ

Jörn Jacobsen 378 Dec 06, 2022
The Balloon Learning Environment - flying stratospheric balloons with deep reinforcement learning.

Balloon Learning Environment Docs The Balloon Learning Environment (BLE) is a simulator for stratospheric balloons. It is designed as a benchmark envi

Google 87 Dec 25, 2022
Low-dose Digital Mammography with Deep Learning

Impact of loss functions on the performance of a deep neural network designed to restore low-dose digital mammography ====== This repository contains

WANG-AXIS 6 Dec 13, 2022
A Dynamic Residual Self-Attention Network for Lightweight Single Image Super-Resolution

DRSAN A Dynamic Residual Self-Attention Network for Lightweight Single Image Super-Resolution Karam Park, Jae Woong Soh, and Nam Ik Cho Environments U

4 May 10, 2022
DetCo: Unsupervised Contrastive Learning for Object Detection

DetCo: Unsupervised Contrastive Learning for Object Detection arxiv link News Sparse RCNN+DetCo improves from 45.0 AP to 46.5 AP(+1.5) with 3x+ms trai

Enze Xie 234 Dec 18, 2022
StarGAN2 for practice

StarGAN2 for practice This version of StarGAN2 (coined as 'Post-modern Style Transfer') is intended mostly for fellow artists, who rarely look at scie

vadim epstein 87 Sep 24, 2022
A robust camera and Lidar fusion based velocity estimator to undistort the pointcloud.

Lidar with Velocity A robust camera and Lidar fusion based velocity estimator to undistort the pointcloud. related paper: Lidar with Velocity : Motion

ISEE Research Group 164 Dec 30, 2022
The codes and related files to reproduce the results for Image Similarity Challenge Track 1.

ISC-Track1-Submission The codes and related files to reproduce the results for Image Similarity Challenge Track 1. Required dependencies To begin with

Wenhao Wang 115 Jan 02, 2023
Code to reproduce the results for Compositional Attention

Compositional-Attention This repository contains the official implementation for the paper Compositional Attention: Disentangling Search and Retrieval

Sarthak Mittal 58 Nov 30, 2022
An Implementation of Transformer in Transformer in TensorFlow for image classification, attention inside local patches

Transformer-in-Transformer An Implementation of the Transformer in Transformer paper by Han et al. for image classification, attention inside local pa

Rishit Dagli 40 Jul 25, 2022